# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes

def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit

    Applies the configured activation(s) to ``x`` through the custom-VJP
    implementation ``_activation_lu``.
    """
    is_gated = len(activation_type) > 1
    if is_gated:
        assert x.shape[-2] == 2    # Linear + GeLU
    return _activation_lu(x, activation_type)

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """Custom-VJP activation unit; only the primal output is returned here."""
    primal_out, _ = _activation_lu_fwd_rule(x, activation_type)
    return primal_out
def _activation_lu_fwd_rule(x, activation_type):
    """Forward rule: run the activation and stash ``x`` for the backward pass."""
    return act_lu(x, activation_type), (x,)
def _activation_lu_bwd_rule(activation_type, ctx, g):
    """Backward rule: gradient of the activation w.r.t. its input ``x``."""
    (x,) = ctx
    assert x.dtype == g.dtype
    grad_x = jnp.reshape(dact_lu(g, x, activation_type), x.shape)
    return (grad_x,)
# Register the forward/backward rules for the custom-VJP activation unit.
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
def fused_layernorm_fp8_mlp(x: jnp.ndarray,
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            biases: List[jnp.ndarray],
                            fp8_meta_pkgs: List[FP8MetaPackage],
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6,
                            layernorm_input_axes: Tuple[str, ...] = None,
                            dot_1_input_axes: Tuple[str, ...] = None,
                            dot_2_input_axes: Tuple[str, ...] = None,
                            ffn1_ckpt_name: str = 'ffn1',
                            ffn2_ckpt_name: str = 'ffn2',
                            activation_type: Sequence[Union[str, Callable]] = ('gelu',),
                            use_bias: bool = True) -> jnp.ndarray:
    """
    Layernorm + GEMM1 + bias + activation + GEMM2 + bias

    Validates the two-GEMM configuration, unpacks the per-GEMM FP8 metadata
    packages, and dispatches to the custom-VJP implementation.
    """
    assert len(kernels) == 2
    assert len(fp8_meta_pkgs) == len(kernels)

    meta_1, meta_2 = fp8_meta_pkgs[0], fp8_meta_pkgs[1]
    kernel_1, kernel_2 = kernels[0], kernels[1]
    bias_1, bias_2 = biases[0], biases[1]

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    return _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
                                    meta_1.amax_list, meta_2.amax_list,
                                    meta_1.scale_list, meta_2.scale_list,
                                    FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE,
                                    layernorm_type, zero_centered_gamma, epsilon,
                                    layernorm_input_axes, dot_1_input_axes,
                                    dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
                                    activation_type, use_bias)


@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
                             bias_2: jnp.ndarray, amax_list_1: List[jnp.ndarray],
                             amax_list_2: List[jnp.ndarray], scale_list_1: List[jnp.ndarray],
                             scale_list_2: List[jnp.ndarray], fwd_dtype: jnp.dtype,
                             bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                             epsilon: float, layernorm_input_axes: Tuple[str, ...],
                             dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                             ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                             activation_type: Sequence[Union[str, Callable]], use_bias: bool):
    """Custom-VJP core of the fused layernorm + FP8 MLP.

    Positional arguments 11-22 (dtypes, layernorm configuration, sharding axes,
    checkpoint names, activation type, use_bias) are declared non-differentiable
    via ``nondiff_argnums``; gradients flow only through the first 11 arguments.
    The primal computation simply runs the forward rule and discards the
    residuals.
    """
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
        x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, amax_list_1, amax_list_2, scale_list_1,
        scale_list_2, fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon,
        layernorm_input_axes, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
        activation_type, use_bias)
    return output


def _fused_layernorm_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias):
    """Forward rule: layernorm -> FP8 GEMM1 -> bias -> activation -> FP8 GEMM2 -> bias.

    Returns ``(dot_2_output, ctx)`` where ``ctx`` carries the residuals needed
    by ``_fused_layernorm_fp8_mlp_bwd_rule``: normalized input, layernorm
    statistics, casted kernels/activations, the (updated) FP8 amax/scale lists,
    and the updated amax values observed during this pass.
    """

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
    # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == len(activation_type)
    assert len(kernel_2.shape) == 2

    # Contract over the trailing (hidden) dim; xt_batch_dims are the batch dims
    # after a transpose moves hidden to the front (used by backward wgrad GEMMs).
    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    # Normalize FP8-meta tensors to fp32 for arithmetic; the inverse converter
    # is saved in ctx so the backward pass can restore the original meta dtype.
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
        FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list_1, *scale_list_1,
                                                         *amax_list_2, *scale_list_2)
    amax_list_1 = maybe_fm32_to_fp32(*amax_list_1)
    scale_list_1 = maybe_fm32_to_fp32(*scale_list_1)
    amax_list_2 = maybe_fm32_to_fp32(*amax_list_2)
    scale_list_2 = maybe_fm32_to_fp32(*scale_list_2)

    # Refresh scales/scale_invs from the amax histories before quantizing.
    # Order matches [input, weight, grad] indices; grads use bwd_dtype.
    fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
    scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(amax_list_1, scale_list_1,
                                                                     fp8_dtype_list)
    amax_list_1 = FP8MetaPackage.update_amax_list(amax_list_1)
    scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(amax_list_2, scale_list_2,
                                                                     fp8_dtype_list)
    amax_list_2 = FP8MetaPackage.update_amax_list(amax_list_2)

    # [0:1] selects the current slot of the amax history (the bwd rule writes
    # updated amaxes back to index 0).
    x_amax = amax_list_1[FP8MetaPackage.INPUT_IDX][0:1]
    x_scale = scale_list_1[FP8MetaPackage.INPUT_IDX]
    x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    # Fused norm that also emits an FP8-quantized output for GEMM1.
    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None    # rmsnorm has no mean statistic

    assert x.shape == ln_out.shape

    kernel_1_amax = amax_list_1[FP8MetaPackage.WEIGHT_IDX][0:1]
    kernel_1_scale = scale_list_1[FP8MetaPackage.WEIGHT_IDX]
    kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]

    # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    if use_bias:
        # Broadcast the bias over the leading (batch) dims; its original shape
        # is kept in ctx so dbias can be reshaped back in the backward pass.
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
    else:
        bias_1_shape = None
    # Tag for activation-checkpointing policies.
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # The activation output is GEMM2's input, so it uses the INPUT slot of the
    # second FP8 meta package.
    activation_lu_out_amax = amax_list_2[FP8MetaPackage.INPUT_IDX][0:1]
    activation_lu_out_scale = scale_list_2[FP8MetaPackage.INPUT_IDX]
    activation_lu_out_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]

    # Fused activation + FP8 cast; collapses the gated pair dim when present.
    casted_activation_lu_out, updated_activation_lu_amax = \
    act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
               activation_lu_out_scale_inv, fwd_dtype, activation_type)

    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
        casted_activation_lu_out, dot_2_input_axes)

    kernel_2_scale = scale_list_2[FP8MetaPackage.WEIGHT_IDX]
    kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # Activation output contracted on its last dim against kernel_2's dim 0.
    dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
                                activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    if use_bias:
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
    else:
        bias_2_shape = None

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    # Residuals for the backward rule (order must match the unpack there).
    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
           casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, scale_inv_list_1,
           scale_inv_list_2, updated_x_amax, updated_activation_lu_amax, updated_kernel_1_amax,
           updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape,
           maybe_fp32_to_fm32)

    return dot_2_output, ctx


def _fused_layernorm_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
        activation_type,
        use_bias,
        ctx,
        grad):
    """Backward rule for the fused layernorm + FP8 MLP.

    Receives the 12 non-differentiable arguments first (per ``nondiff_argnums``),
    then the residuals ``ctx`` from the forward rule and the incoming gradient.
    Returns one gradient per differentiable forward argument, in order:
    (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2,
     amax_list_1, amax_list_2, scale_list_1, scale_list_2).
    """
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
    casted_kernel_1, casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, \
    scale_inv_list_1, scale_inv_list_2, updated_x_amax, \
    updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx

    grad_amax = amax_list_2[FP8MetaPackage.GRAD_IDX][0:1]
    grad_scale = scale_list_2[FP8MetaPackage.GRAD_IDX]
    grad_scale_inv = scale_inv_list_2[FP8MetaPackage.GRAD_IDX]

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    # Cast the incoming grad to FP8 (plus its transpose for the wgrad GEMM);
    # the fused variant also reduces the bias gradient of GEMM2.
    if use_bias:
        casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
        dbias_cast_transpose(grad, grad_amax, grad_scale,
                             grad_scale_inv, bwd_dtype,
                             static_axis_boundary=-1,
                             transpose_axis_boundary=-1)
        dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
    else:
        casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale,
                       grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1,
                       transpose_axis_boundary=-1)
        dbias_2 = None

    casted_activation_lu_out_t = transpose(casted_activation_lu_out,
                                           static_axis_boundary=-1,
                                           transpose_axis_boundary=-1)

    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
    wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
                           grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    # The activation gradient is GEMM1's grad, so it uses the GRAD slot of the
    # first FP8 meta package.
    dactivation_lu_amax = amax_list_1[FP8MetaPackage.GRAD_IDX][0:1]
    dactivation_lu_scale = scale_list_1[FP8MetaPackage.GRAD_IDX]
    dactivation_lu_scale_inv = scale_inv_list_1[FP8MetaPackage.GRAD_IDX]

    # Four variants of the activation backward + FP8 cast/transpose, chosen by
    # (gated activation?) x (bias gradient needed?); each yields the casted
    # activation grad and its transpose, plus dbias_1 when use_bias.
    if len(activation_type) > 1:    # if gated
        if use_bias:
            dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
            dbias_cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            dgated_act_lu_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                activation_type=activation_type)
            dbias_1 = None
    else:
        if use_bias:
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
            dact_lu_dbias_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2,
                activation_type=activation_type)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = None

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
    # The activation-grad transpose has one extra leading dim, so its batch
    # dims sit one position later than ln_out_t's.
    xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
                           dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # Contract the activation grad's trailing two dims against kernel_1's
    # (act, hidden_out) dims (1, 2).
    x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
                          (1, 2))
    kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
                           kernel_1_scale_inv, grad.dtype, x_contracting_dims,
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None    # rmsnorm has no beta

    # Write the amaxes observed this step into slot 0 of each rolling history.
    # (updated_kernel_2_amax comes from `quantize` and is already a scalar.)
    amax_list_1[FP8MetaPackage.INPUT_IDX] = \
        amax_list_1[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
    amax_list_1[FP8MetaPackage.WEIGHT_IDX] = \
        amax_list_1[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_1_amax[0])
    amax_list_1[FP8MetaPackage.GRAD_IDX] = \
        amax_list_1[FP8MetaPackage.GRAD_IDX].at[0].set(updated_dactivation_lu_amax[0])
    amax_list_2[FP8MetaPackage.INPUT_IDX] = \
        amax_list_2[FP8MetaPackage.INPUT_IDX].at[0].set(updated_activation_lu_amax[0])
    amax_list_2[FP8MetaPackage.WEIGHT_IDX] = \
        amax_list_2[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_2_amax)
    amax_list_2[FP8MetaPackage.GRAD_IDX] = \
        amax_list_2[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])

    # Restore the original FP8-meta dtype before handing the lists back.
    amax_list_1 = maybe_fp32_to_fm32(*amax_list_1)
    scale_list_1 = maybe_fp32_to_fm32(*scale_list_1)
    amax_list_2 = maybe_fp32_to_fm32(*amax_list_2)
    scale_list_2 = maybe_fp32_to_fm32(*scale_list_2)

    return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
           amax_list_1, amax_list_2, scale_list_1, scale_list_2
# Register the forward/backward rules for the fused layernorm + FP8 MLP.
_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
                                _fused_layernorm_fp8_mlp_bwd_rule)