# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
from .cpp_extensions import gelu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import silu, silu_fp8
from .cpp_extensions import dsilu, dsilu_dbias_cast_transpose
from .cpp_extensions import gated_silu, gated_silu_fp8
from .cpp_extensions import dgated_silu, dgated_silu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes

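# Dispatch tables mapping an activation_type tuple to its forward/backward kernels.
# Tuples ending in 'linear' are gated variants that consume an extra size-2 axis
# holding the activation input and the linear gate. activation_fp8_dict holds the
# fused FP8 cast / cast-transpose counterparts used on the FP8 code paths.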
activation_dict = {
    ('gelu',): {
        'fwd': gelu,
        "bwd": dgelu
    },
    ('gelu', 'linear'): {
        'fwd': gated_gelu,
        'bwd': dgated_gelu
    },
    ('silu',): {
        'fwd': silu,
        "bwd": dsilu
    },
    ('silu', 'linear'): {
        'fwd': gated_silu,
        'bwd': dgated_silu
    }
}

activation_fp8_dict = {
    ('gelu',): {
        'fwd': gelu_fp8,
        'bwd': dgelu_dbias_cast_transpose
    },
    ('gelu', 'linear'): {
        'fwd': gated_gelu_fp8,
        'bwd': dgated_gelu_cast_transpose
    },
    ('silu',): {
        'fwd': silu_fp8,
        'bwd': dsilu_dbias_cast_transpose
    },
    ('silu', 'linear'): {
        'fwd': gated_silu_fp8,
        'bwd': dgated_silu_cast_transpose
    }
}


def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit: applies the activation given by `activation_type` to `x`.
    Gated variants (e.g. ('gelu', 'linear')) expect the gate along axis -2 of `x`.
    """
    if len(activation_type) > 1:
        assert x.shape[-2] == 2    # gated: (..., 2, hidden) = (activation input, linear gate)
    output = _activation_lu(x, activation_type)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):

    _output, _ = _activation_lu_fwd_rule(x, activation_type)

    return _output


def _activation_lu_fwd_rule(x, activation_type):
    fwd_output = activation_dict[activation_type]["fwd"](x)
    return fwd_output, (x,)


def _activation_lu_bwd_rule(activation_type, ctx, g):
    x, = ctx
    assert x.dtype == g.dtype

    dx = activation_dict[activation_type]["bwd"](g, x)
    dx = jnp.reshape(dx, x.shape)
    return (dx,)


_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
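
# Illustrative sketch of calling activation_lu (the values below are made-up examples,
# not part of this module); gated variants expect the gate along axis -2:
#
#     x = jax.random.normal(jax.random.PRNGKey(0), (4, 128, 2, 256))
#     y = activation_lu(x, ('silu', 'linear'))    # expected result shape: (4, 128, 256)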


def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                      fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                      activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit with FP8 output: the forward activation runs in the fused FP8
    kernel and is dequantized back to the input dtype; the backward pass uses the
    fused FP8 cast(-transpose) kernels.
    """
    transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
    dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
    dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)

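    # dx_trans_no_use and dbias_no_use above are placeholders: their values are never read
    # in the forward pass, but their cotangent slots let the custom VJP hand back the
    # transposed gradient and the bias gradient produced by the fused backward kernel.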
    output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv, fwd_dtype,
                                bwd_dtype, activation_type)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
                       amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                       fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                       activation_type: Sequence[Union[str, Callable]]):

    output, _ = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, scale,
                                            scale_inv, fwd_dtype, bwd_dtype, activation_type)

    return output


def _activation_lu_fp8_fwd_rule(
        x,
        dx_trans_no_use,    # pylint: disable=unused-argument
        dbias_no_use,    # pylint: disable=unused-argument
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        activation_type):
    activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](x, amax, scale, scale_inv,
                                                                       fwd_dtype)

    activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
    ctx = (x, amax, scale, scale_inv)
    return activation_lu_out, ctx


def _activation_lu_fp8_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        activation_type,
        ctx,
        g):
    x, amax, scale, scale_inv = ctx

    activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"]
    if len(activation_type) > 1:    # gated: the fused bwd kernel produces no dbias
        dactivation_lu, dactivation_lu_trans, amax_out = \
        activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1)
        dbias = jnp.empty(x.shape[-1], x.dtype)
    else:
        dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
        activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1)
    dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
    dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
    grads = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
    return grads


_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule)


def fused_layernorm_fp8_mlp(x: jnp.ndarray,
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            biases: List[jnp.ndarray],
                            fp8_gemm_pkg: FP8MetaPackage,
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6,
                            layernorm_input_axes: Tuple[str, ...] = None,
                            dot_1_input_axes: Tuple[str, ...] = None,
                            dot_2_input_axes: Tuple[str, ...] = None,
                            ffn1_ckpt_name: str = 'ffn1',
                            ffn2_ckpt_name: str = 'ffn2',
                            activation_type: Sequence[Union[str, Callable]] = ('gelu',),
                            use_bias: bool = True) -> jnp.ndarray:
    """
    Fused Layernorm (or RMSNorm) + GEMM1 + bias + activation + GEMM2 + bias,
    with both GEMMs executed in FP8 using the scaling metadata carried by
    `fp8_gemm_pkg`. `activation_type` selects the activation, e.g. ('gelu',) or the
    gated ('gelu', 'linear'); set `use_bias=False` to skip the bias additions.
    """

    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
                                      amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                      zero_centered_gamma, epsilon, layernorm_input_axes,
                                      dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
                                      ffn2_ckpt_name, activation_type, use_bias)
    return output


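# Illustrative call sketch for fused_layernorm_fp8_mlp above (shapes and the fp8_meta_pkg
# variable are made-up; in practice the FP8MetaPackage comes from the FP8 helper machinery
# wrapping the training step):
#
#     y = fused_layernorm_fp8_mlp(x, gamma, None, [kernel_1, kernel_2], [bias_1, bias_2],
#                                 fp8_meta_pkg, layernorm_type='rmsnorm',
#                                 activation_type=('gelu', 'linear'))
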
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
                             bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
                             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
                             bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                             epsilon: float, layernorm_input_axes: Tuple[str, ...],
                             dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                             ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                             activation_type: Sequence[Union[str, Callable]], use_bias: bool):
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
        x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv,
        fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes,
        dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type,
        use_bias)
    return output


def _fused_layernorm_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias):

    is_gated = len(activation_type) > 1
    # x should be of shape (batch..., hidden)
    # kernel_1 should be of shape (hidden_in, len(activation_type), hidden_out)
    # kernel_2 should be of shape (hidden_in, hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == len(activation_type)
    assert len(kernel_2.shape) == 2

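    # GEMMs contract over the trailing hidden axis of x; xt_batch_dims are the positions
    # of the batch axes once tensors are transposed to (hidden, batch...) for the wgrad dots.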
    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    # Squeeze act axis
    # (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
    if not is_gated:
        kernel_1 = jnp.squeeze(kernel_1, axis=-2)

    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
        FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
    fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)

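    # Refresh the per-tensor scales (and their inverses) from the current amax history
    # before anything is quantized in this step.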
    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
    amax = FP8Helper.update_amax_history(amax)

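    # Each GEMM owns three FP8 meta slots (input, kernel, grad); the argument selects the
    # slots of the first (0) or second (1) GEMM.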
    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

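    # The fused normalization kernels emit ln_out already cast to FP8 (fwd_dtype), along
    # with the updated amax for the GEMM1 input.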
    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    # Note (Ming Huang): Use cast only (no fused transpose) so that XLA handles the
    # transpose itself, avoiding an unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    if use_bias:
        bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
        dot_1_output += jnp.reshape(bias_1, bias_1_shape)
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

    activation_lu_out_amax = amax[gemm2_x_idx, 0:1]
    activation_lu_out_scale = scale[gemm2_x_idx]
    activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]

    activation_lu_fwd_fp8 = activation_fp8_dict[activation_type]["fwd"]

    # (batch..., act_len, hidden_out) -> (batch..., hidden_out) for gated; elementwise otherwise
    casted_activation_lu_out, updated_activation_lu_amax = \
        activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
                              activation_lu_out_scale_inv, fwd_dtype)

    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
        casted_activation_lu_out, dot_2_input_axes)

    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use a native cast (no fused transpose) so that XLA handles the
    # transpose itself, avoiding an unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
                                activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    if use_bias:
        bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
        dot_2_output += jnp.reshape(bias_2, bias_2_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
           x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape, maybe_fp32_to_fm32)

    return dot_2_output, ctx


def _fused_layernorm_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
        activation_type,
        use_bias,
        ctx,
        grad):
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
    casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
    updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx

    is_gated = len(activation_type) > 1

    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    # The incoming gradient shares the MLP output's sharding, which matches dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

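    # Cast the incoming gradient to FP8 and build its transpose; with use_bias the fused
    # kernel also reduces the bias gradient (dbias_2) in the same pass.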
    if use_bias:
        casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
        dbias_cast_transpose(grad, grad_amax, grad_scale,
                             grad_scale_inv, bwd_dtype,
                             static_axis_boundary=-1,
                             transpose_axis_boundary=-1)
        dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
    else:
        casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale,
                       grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1,
                       transpose_axis_boundary=-1)
        dbias_2 = jnp.empty(bias_2_shape, grad.dtype)

    casted_activation_lu_out_t = transpose(casted_activation_lu_out,
                                           static_axis_boundary=-1,
                                           transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
    wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
                           grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    dactivation_lu_amax = amax[gemm1_grad_idx, 0:1]
    dactivation_lu_scale = scale[gemm1_grad_idx]
    dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]

    dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"]
    dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output)

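    # Four cases for the GEMM1-side gradient:
    #   gated + bias:        dbias_cast_transpose on the precomputed dactivation_lu
    #                        (transpose_axis_boundary=-2 because of the extra act axis)
    #   gated + no bias:     fused dgated_*_cast_transpose (activation bwd + cast + transpose)
    #   non-gated + bias:    fused d*_dbias_cast_transpose (activation bwd + dbias + cast + transpose)
    #   non-gated + no bias: cast_transpose on the precomputed dactivation_lu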
    if is_gated:
        if use_bias:
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
            dbias_cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            dactivation_lu_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1)
            dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
    else:
        if use_bias:
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
            dactivation_lu_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-1)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-1)
            dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
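    # The transposed dactivation of a gated activation carries an extra act axis ahead of
    # the batch dims, so those contraction indices shift by one.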
    xt_batch_dims_2 = xt_batch_dims if not is_gated \
        else tuple(i + 1 for i in xt_batch_dims)
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
                           dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
    # Expand the act axis so that wgrad_1 matches the shape of the given kernel_1
    if not is_gated:
        wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    # (gated activations also contract over the act axis of kernel_1)
    if is_gated:
        x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
                              (1, 2))
    else:
        x_contracting_dims = (x_contracting_dims, (1,))
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
                           kernel_1_scale_inv, grad.dtype, x_contracting_dims,
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

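    # Record the freshly measured amax values into slot 0 of the rolling amax history.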
    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0])
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

    fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)

    return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
           fp8_max, amax, scale, scale_inv


_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
                                _fused_layernorm_fp8_mlp_bwd_rule)