"tests/pytorch/attention/test_attention.py" did not exist on "30cad990d09fce3c37951d09c6ec085c1216a313"
layernorm_mlp.py 19.4 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes


def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit
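
    A rough usage sketch (illustrative only)::

        y = activation_lu(x, ('gelu',))    # x: (..., hidden) -> y: (..., hidden)

    For gated activations (len(activation_type) > 1) the input is expected to stack the two
    branches along the second-to-last axis, i.e. x: (..., 2, hidden).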
    """
    if len(activation_type) > 1:
        assert x.shape[-2] == 2    # Linear + GeLU
    output = _activation_lu(x, activation_type)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):

    _output, _ = _activation_lu_fwd_rule(x, activation_type)

    return _output


def _activation_lu_fwd_rule(x, activation_type):
    fwd_output = tex.act_lu(x, activation_type)
    return fwd_output, (x,)


def _activation_lu_bwd_rule(activation_type, ctx, g):
    x, = ctx
    assert x.dtype == g.dtype

    dx = tex.dact_lu(g, x, activation_type)
    dx = jnp.reshape(dx, x.shape)
    return (dx,)


_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)


def fused_layernorm_fp8_mlp(x: jnp.ndarray,
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            biases: List[jnp.ndarray],
                            fp8_meta_pkgs: List[FP8MetaPackage],
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6,
                            layernorm_input_axes: Tuple[str, ...] = None,
                            dot_1_input_axes: Tuple[str, ...] = None,
                            dot_2_input_axes: Tuple[str, ...] = None,
                            ffn1_ckpt_name: str = 'ffn1',
                            ffn2_ckpt_name: str = 'ffn2',
                            activation_type: Sequence[Union[str, Callable]] = ('gelu',),
                            use_bias: bool = True) -> jnp.ndarray:
    """
    Layernorm + GEMM1 + bias + activation + GEMM2 + bias
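
    A rough usage sketch (illustrative; the kernel/bias shapes and the two FP8MetaPackage
    instances are assumed to be supplied by the caller's FP8 bookkeeping)::

        out = fused_layernorm_fp8_mlp(
            x, gamma, beta,
            kernels=[kernel_1, kernel_2],          # e.g. (H, 1, 4H) and (4H, H) for ('gelu',)
            biases=[bias_1, bias_2],
            fp8_meta_pkgs=[fp8_pkg_1, fp8_pkg_2],  # one FP8MetaPackage per GEMM
            layernorm_type='layernorm',
            activation_type=('gelu',))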
    """

    assert len(kernels) == 2
    assert len(fp8_meta_pkgs) == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]
    amax_list_1 = fp8_meta_pkgs[0].amax_list
    amax_list_2 = fp8_meta_pkgs[1].amax_list
    scale_list_1 = fp8_meta_pkgs[0].scale_list
    scale_list_2 = fp8_meta_pkgs[1].scale_list

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
                                      amax_list_1, amax_list_2, scale_list_1, scale_list_2,
                                      fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma,
                                      epsilon, layernorm_input_axes, dot_1_input_axes,
                                      dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
                                      activation_type, use_bias)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
                             bias_2: jnp.ndarray, amax_list_1: List[jnp.ndarray],
                             amax_list_2: List[jnp.ndarray], scale_list_1: List[jnp.ndarray],
                             scale_list_2: List[jnp.ndarray], fwd_dtype: jnp.dtype,
                             bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                             epsilon: float, layernorm_input_axes: Tuple[str, ...],
                             dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                             ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                             activation_type: Sequence[Union[str, Callable]], use_bias: bool):
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
        x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, amax_list_1, amax_list_2, scale_list_1,
        scale_list_2, fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon,
        layernorm_input_axes, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
        activation_type, use_bias)
    return output


def _fused_layernorm_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias):
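
    # Forward pass: LayerNorm/RMSNorm fused with an FP8 cast of its output, FP8 GEMM1 (+ bias),
    # activation fused with an FP8 cast of its output, FP8 GEMM2 (+ bias). Everything the
    # backward rule needs is stashed in ctx at the end.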

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (Hidden_in, len(activation_type), Hidden_out)
    # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == len(activation_type)
    assert len(kernel_2.shape) == 2

    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
        FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list_1, *scale_list_1,
                                                         *amax_list_2, *scale_list_2)
    amax_list_1 = maybe_fm32_to_fp32(*amax_list_1)
    scale_list_1 = maybe_fm32_to_fp32(*scale_list_1)
    amax_list_2 = maybe_fm32_to_fp32(*amax_list_2)
    scale_list_2 = maybe_fm32_to_fp32(*scale_list_2)

    fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
    scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(amax_list_1, scale_list_1,
                                                                     fp8_dtype_list)
    amax_list_1 = FP8MetaPackage.update_amax_list(amax_list_1)
    scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(amax_list_2, scale_list_2,
                                                                     fp8_dtype_list)
    amax_list_2 = FP8MetaPackage.update_amax_list(amax_list_2)

    x_amax = amax_list_1[FP8MetaPackage.INPUT_IDX][0:1]
    x_scale = scale_list_1[FP8MetaPackage.INPUT_IDX]
    x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax_list_1[FP8MetaPackage.WEIGHT_IDX][0:1]
    kernel_1_scale = scale_list_1[FP8MetaPackage.WEIGHT_IDX]
    kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]

    # Note (Ming Huang): Use cast only and let XLA handle the transpose, to avoid an
    # unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        tex.cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    if use_bias:
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
    else:
        bias_1_shape = None
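
    # Tag the FFN1 output with a checkpoint name so jax.ad_checkpoint rematerialization
    # policies can refer to it (ffn1_ckpt_name, 'ffn1' by default).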
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    activation_lu_out_amax = amax_list_2[FP8MetaPackage.INPUT_IDX][0:1]
    activation_lu_out_scale = scale_list_2[FP8MetaPackage.INPUT_IDX]
    activation_lu_out_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]

    # (batch..., hidden_in) -> (batch..., hidden)
    casted_activation_lu_out, updated_activation_lu_amax = \
    tex.act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
               activation_lu_out_scale_inv, fwd_dtype, activation_type)

    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
        casted_activation_lu_out, dot_2_input_axes)

    kernel_2_scale = scale_list_2[FP8MetaPackage.WEIGHT_IDX]
    kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    # Note (Ming Huang): Use native cast and let XLA handle the transpose, to avoid an
    # unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_in) x (hidden_out, hidden_in)
    dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
                                activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    if use_bias:
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
    else:
        bias_2_shape = None

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
           casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, scale_inv_list_1,
           scale_inv_list_2, updated_x_amax, updated_activation_lu_amax, updated_kernel_1_amax,
           updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape,
           maybe_fp32_to_fm32)

    return dot_2_output, ctx


def _fused_layernorm_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
        activation_type,
        use_bias,
        ctx,
        grad):
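
    # Backward pass: unpack the saved tensors, cast the incoming gradient to FP8, run the FP8
    # wgrad/dgrad GEMMs for both layers, back-propagate through the activation and the
    # normalization, then fold the updated amaxes back into the FP8 metadata.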
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
    casted_kernel_1, casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, \
    scale_inv_list_1, scale_inv_list_2, updated_x_amax, \
    updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx

    grad_amax = amax_list_2[FP8MetaPackage.GRAD_IDX][0:1]
    grad_scale = scale_list_2[FP8MetaPackage.GRAD_IDX]
    grad_scale_inv = scale_inv_list_2[FP8MetaPackage.GRAD_IDX]

    # The output (and hence the incoming gradient) shares the sharding of dot_1's input,
    # so constrain the gradient accordingly.
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
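
    # Cast the incoming gradient to FP8 in both layouts: the non-transposed copy feeds the
    # dgrad GEMM and the transposed copy feeds the wgrad GEMM; dbias_2 is fused into the cast
    # when use_bias is set.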
    if use_bias:
        casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
        tex.dbias_cast_transpose(grad, grad_amax, grad_scale,
                             grad_scale_inv, bwd_dtype,
                             static_axis_boundary=-1,
                             transpose_axis_boundary=-1)
        dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
    else:
        casted_grad, casted_grad_t, updated_grad_amax = \
        tex.cast_transpose(grad, grad_amax, grad_scale,
                       grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1,
                       transpose_axis_boundary=-1)
        dbias_2 = None

    casted_activation_lu_out_t = tex.transpose(casted_activation_lu_out,
                                           static_axis_boundary=-1,
                                           transpose_axis_boundary=-1)

    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
    wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
                           grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    dactivation_lu_amax = amax_list_1[FP8MetaPackage.GRAD_IDX][0:1]
    dactivation_lu_scale = scale_list_1[FP8MetaPackage.GRAD_IDX]
    dactivation_lu_scale_inv = scale_inv_list_1[FP8MetaPackage.GRAD_IDX]
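
    # The activation backward is fused with the FP8 cast/transpose (and with dbias_1 where
    # possible); gated activations (len(activation_type) > 1) use the gated variants.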
    if len(activation_type) > 1:    # if gated
        if use_bias:
            dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
            tex.dbias_cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            tex.dgated_act_lu_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                activation_type=activation_type)
            dbias_1 = None
    else:
        if use_bias:
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
            tex.dact_lu_dbias_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2,
                activation_type=activation_type)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            tex.cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = None
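
    # Transpose the FP8 layernorm output so wgrad_1 can contract it with the transposed
    # activation gradient over the batch dimensions.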
    ln_out_t = tex.transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
    xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
                           dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
                          (1, 2))
    kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
                           kernel_1_scale_inv, grad.dtype, x_contracting_dims,
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = tex.layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = tex.rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

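    # Write the freshly measured amaxes into slot 0 of each amax history so the next
    # iteration's scale update sees them.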
    amax_list_1[FP8MetaPackage.INPUT_IDX] = \
        amax_list_1[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
    amax_list_1[FP8MetaPackage.WEIGHT_IDX] = \
        amax_list_1[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_1_amax[0])
    amax_list_1[FP8MetaPackage.GRAD_IDX] = \
        amax_list_1[FP8MetaPackage.GRAD_IDX].at[0].set(updated_dactivation_lu_amax[0])
    amax_list_2[FP8MetaPackage.INPUT_IDX] = \
        amax_list_2[FP8MetaPackage.INPUT_IDX].at[0].set(updated_activation_lu_amax[0])
    amax_list_2[FP8MetaPackage.WEIGHT_IDX] = \
        amax_list_2[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_2_amax)
    amax_list_2[FP8MetaPackage.GRAD_IDX] = \
        amax_list_2[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])

    amax_list_1 = maybe_fp32_to_fm32(*amax_list_1)
    scale_list_1 = maybe_fp32_to_fm32(*scale_list_1)
    amax_list_2 = maybe_fp32_to_fm32(*amax_list_2)
    scale_list_2 = maybe_fp32_to_fm32(*scale_list_2)

    return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
           amax_list_1, amax_list_2, scale_list_1, scale_list_2


_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
                                _fused_layernorm_fp8_mlp_bwd_rule)