mlp.py 21.1 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""JAX MLP modules"""

6
from typing import List, Tuple, Sequence, Union, Callable
7
from functools import partial
8
9
10

import jax
import jax.numpy as jnp
11
from jax.ad_checkpoint import checkpoint_name
12

13
from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
14
15
from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
16
17
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
18
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
19
from .layernorm import canonicalize_layernorm_type
20
from .fp8 import FP8Helper, FP8MetaPackage
21
22
from .sharding import with_sharding_constraint_by_logical_axes

23

24
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
25
    """
26
    Activation Unit
27
    """
28
    if len(activation_type) > 1:
29
        assert x.shape[-2] == 2    # Linear + GeLU
30
    output = _activation_lu(x, activation_type)
31
32
    return output

33

34
35
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
36

37
    _output, _ = _activation_lu_fwd_rule(x, activation_type)
38

39
    return _output
40

41

42
def _activation_lu_fwd_rule(x, activation_type):
43
    fwd_output = act_lu(x, activation_type)
44
    return fwd_output, (x,)
45

46

47
def _activation_lu_bwd_rule(activation_type, ctx, g):
48
49
    x, = ctx
    assert x.dtype == g.dtype
50

51
    dx = dact_lu(g, x, activation_type)
52
53
    dx = jnp.reshape(dx, x.shape)
    return (dx,)
54

55

56
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
57

58

59
60
def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                      fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
61
62
63
64
                      activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit
    """
65
    transpose_indices = (1, 2, 0)
66
67
68
    dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
    dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)

69
70
    output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv, fwd_dtype,
                                bwd_dtype, activation_type)
71
72
    return output

73
74
75

@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
76
77
78
79
                       amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                       fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                       activation_type: Sequence[Union[str, Callable]]):

80
81
    output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv,
                                         fwd_dtype, bwd_dtype, activation_type)
82
83
84

    return output

85

86
87
88
89
90
91
92
93
94
def _activation_lu_fp8_fwd_rule(x,
                                dx_trans_no_use,    # pylint: disable=unused-argument
                                dbias_no_use,   # pylint: disable=unused-argument
                                amax,
                                scale, scale_inv,
                                fwd_dtype, bwd_dtype,   # pylint: disable=unused-argument
                                activation_type):
    activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv, fwd_dtype,
                                      activation_type)
95
96
97
98
    activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
    ctx = (x, amax, scale, scale_inv)
    return activation_lu_out, ctx

99
100
101
102
103
104
105

def _activation_lu_fp8_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        activation_type,
        ctx,
        g):
106
107
    x, amax, scale, scale_inv = ctx

108
    if len(activation_type) > 1: #gated, no bias
109
        dactivation_lu, dactivation_lu_trans, amax_out = \
110
        dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype, -1, activation_type)
111
        dbias = jnp.empty(x.shape[-1], x.dtype)
112
    else: #not gated, with bias
113
        dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
114
115
        dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype,
                                     -1, -2, activation_type)
116
117
118
119
120
    dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
    dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
    ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
    return ctx

121

122
123
124
_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule)


125
def fused_layernorm_fp8_mlp(x: jnp.ndarray,
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            biases: List[jnp.ndarray],
                            fp8_gemm_pkg: FP8MetaPackage,
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6,
                            layernorm_input_axes: Tuple[str, ...] = None,
                            dot_1_input_axes: Tuple[str, ...] = None,
                            dot_2_input_axes: Tuple[str, ...] = None,
                            ffn1_ckpt_name: str = 'ffn1',
                            ffn2_ckpt_name: str = 'ffn2',
                            activation_type: Sequence[Union[str, Callable]] = ('gelu',),
                            use_bias: bool = True) -> jnp.ndarray:
141
    """
142
    Layernorm + GEMM1 + bias + activation + GEMM2 + bias
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
    """

    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

166
    output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
167
168
169
170
                                      amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                      zero_centered_gamma, epsilon, layernorm_input_axes,
                                      dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
                                      ffn2_ckpt_name, activation_type, use_bias)
171
172
173
    return output


174
175
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
176
177
178
179
180
181
182
183
184
185
186
187
188
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
                             bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
                             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
                             bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                             epsilon: float, layernorm_input_axes: Tuple[str, ...],
                             dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                             ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                             activation_type: Sequence[Union[str, Callable]], use_bias: bool):
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
        x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv,
        fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes,
        dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type,
        use_bias)
189
190
191
    return output


192
def _fused_layernorm_fp8_mlp_fwd_rule(
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
213
214
215
        ffn2_ckpt_name,
        activation_type,
        use_bias):
216
217
218
219
220

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
    # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
    assert len(kernel_1.shape) == 3
221
    assert kernel_1.shape[-2] == len(activation_type)
222
223
224
225
226
227
228
229
    assert len(kernel_2.shape) == 2

    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

230
231
232
233
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
        FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
    fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
    amax = FP8Helper.update_amax_history(amax)

    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
284
    if use_bias:
285
286
287
288
289
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
    else:
        bias_1_shape = None
290
291
292
293
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

294
295
296
297
    activation_lu_out_amax = amax[gemm2_x_idx, 0:1]
    activation_lu_out_scale = scale[gemm2_x_idx]
    activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]

298
299

    # (batch..., hidden_in) -> (batch..., hidden)
300
    casted_activation_lu_out, updated_activation_lu_amax = \
301
302
    act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
               activation_lu_out_scale_inv, fwd_dtype, activation_type)
303

304
305
    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
        casted_activation_lu_out, dot_2_input_axes)
306
307
308
309
310
311
312
313

    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_in) x (hidden_out, hidden_in)
314
    dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
315
316
                                activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
317
318
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

319
    if use_bias:
320
321
322
323
324
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
    else:
        bias_2_shape = None
325

326
327
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

328
329
330
    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
331
           x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32)
332
333
334
335

    return dot_2_output, ctx


336
def _fused_layernorm_fp8_mlp_bwd_rule(
337
338
339
340
341
342
343
344
345
346
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
347
348
        activation_type,
        use_bias,
349
350
        ctx,
        grad):
351
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
352
    casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
353
    updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
354
    x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx
355
356
357
358
359
360
361
362
363

    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
364
365
366
367
368
369
370
371
372
373
374
375
376
    if use_bias:
        casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
        dbias_cast_transpose(grad, grad_amax, grad_scale,
                             grad_scale_inv, bwd_dtype,
                             static_axis_boundary=-1,
                             transpose_axis_boundary=-1)
        dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
    else:
        casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale,
                       grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1,
                       transpose_axis_boundary=-1)
377
        dbias_2 = None
378

379
380
381
    casted_activation_lu_out_t = transpose(casted_activation_lu_out,
                                           static_axis_boundary=-1,
                                           transpose_axis_boundary=-1)
382
383
384

    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
385
386
    wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
                           grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
387
388
389
390
391
392
393
394
395
396
397
398
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

399
400
401
402
    dactivation_lu_amax = amax[gemm1_grad_idx, 0:1]
    dactivation_lu_scale = scale[gemm1_grad_idx]
    dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]

403
    if len(activation_type) > 1:    # if gated
404
        if use_bias:
405
            dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
406
407
408
409
410
411
412
413
414
415
416
417
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
            dbias_cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
418
            dgated_act_lu_cast_transpose(
419
420
421
422
423
424
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
425
426
427
                static_axis_boundary=-1,
                activation_type=activation_type)
            dbias_1 = None
428
429
    else:
        if use_bias:
430
431
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
            dact_lu_dbias_cast_transpose(
432
433
434
435
436
437
438
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
439
440
                transpose_axis_boundary=-2,
                activation_type=activation_type)
441
442
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
443
            dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
444
445
446
447
448
449
450
451
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
452
453
                transpose_axis_boundary=-2)
            dbias_1 = None
454
455
456
457
458

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
459
    xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
460
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
461
                           dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
462
463
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

464
465
    x_contracting_dims = ((min(x_contracting_dims),) + tuple(
            i + 1 for i in x_contracting_dims), (1,2))
466
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
467
468
    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
                           kernel_1_scale_inv, grad.dtype, x_contracting_dims,
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
489
490
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0])
491
492
493
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

494
495
    fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)

496
497
498
499
    return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
           fp8_max, amax, scale, scale_inv


500
_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
501
                                _fused_layernorm_fp8_mlp_bwd_rule)