# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

from typing import List, Tuple, Sequence, Union, Callable, Optional
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes


def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit
    """
    if len(activation_type) > 1:
        assert x.shape[-2] == 2  # gated activation: a stacked (linear, gate) pair
    output = _activation_lu(x, activation_type)
    return output
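
# Example usage (illustrative sketch; shapes are hypothetical): a gated
# activation consumes a stacked (linear, gate) pair on the second-to-last axis.
#
#     x = jnp.zeros((4, 128, 2, 1024))  # (batch, seqlen, 2, hidden)
#     y = activation_lu(x, activation_type=("gelu", "linear"))
#     # y: (4, 128, 1024) -- the pair axis is collapsed by the fused kernel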

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    _output, _ = _activation_lu_fwd_rule(x, activation_type)
    return _output


def _activation_lu_fwd_rule(x, activation_type):
    fwd_output = tex.act_lu(x, activation_type)
    return fwd_output, (x,)


def _activation_lu_bwd_rule(activation_type, ctx, g):
    (x,) = ctx
    assert x.dtype == g.dtype
    dx = tex.dact_lu(g, x, activation_type)
    dx = jnp.reshape(dx, x.shape)
    return (dx,)


_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
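
# Differentiating through activation_lu now dispatches to tex.dact_lu, e.g.
# (illustrative): jax.grad(lambda v: activation_lu(v, ("gelu",)).sum())(x)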


def fused_layernorm_fp8_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    fp8_meta_pkgs: List[FP8MetaPackage],
    layernorm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    layernorm_input_axes: Optional[Tuple[str, ...]] = None,
    dot_1_input_axes: Optional[Tuple[str, ...]] = None,
    dot_2_input_axes: Optional[Tuple[str, ...]] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    use_bias: bool = True,
) -> jnp.ndarray:
    """
    Layernorm + GEMM1 + bias + activation + GEMM2 + bias
    """
    assert len(kernels) == 2
    assert len(fp8_meta_pkgs) == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]
    amax_list_1 = fp8_meta_pkgs[0].amax_list
    amax_list_2 = fp8_meta_pkgs[1].amax_list
    scale_list_1 = fp8_meta_pkgs[0].scale_list
    scale_list_2 = fp8_meta_pkgs[1].scale_list

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "rmsnorm":
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"

    output = _fused_layernorm_fp8_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        fwd_dtype,
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias,
    )
    return output
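
# Illustrative usage sketch (hypothetical names and shapes; the FP8MetaPackage
# construction depends on the caller's FP8 recipe):
#
#     out = fused_layernorm_fp8_mlp(
#         x, gamma, beta=None,
#         kernels=[w1, w2],  # (H, 2, 4H) and (4H, H)
#         biases=[b1, b2],
#         fp8_meta_pkgs=[pkg_1, pkg_2],
#         layernorm_type="rmsnorm",
#         activation_type=("gelu", "linear"),
#     )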


@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    amax_list_1: List[jnp.ndarray],
    amax_list_2: List[jnp.ndarray],
    scale_list_1: List[jnp.ndarray],
    scale_list_2: List[jnp.ndarray],
    fwd_dtype: jnp.dtype,
    bwd_dtype: jnp.dtype,
    layernorm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    layernorm_input_axes: Optional[Tuple[str, ...]],
    dot_1_input_axes: Optional[Tuple[str, ...]],
    dot_2_input_axes: Optional[Tuple[str, ...]],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    use_bias: bool,
):
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        fwd_dtype,
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias,
    )
    return output


def _fused_layernorm_fp8_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    amax_list_1,
    amax_list_2,
    scale_list_1,
    scale_list_2,
    fwd_dtype,
    bwd_dtype,  # pylint: disable=unused-argument
    layernorm_type,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    use_bias,
):

    # x: (batch..., hidden_in)
    # kernel_1: (hidden_in, act_len, hidden_out), act_len = len(activation_type)
    # kernel_2: (hidden_in, hidden_out), its hidden_in being kernel_1's hidden_out
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == len(activation_type)
    assert len(kernel_2.shape) == 2

    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]
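    # e.g. for x of shape (batch, seqlen, hidden) with a gated activation
    # (act_len == 2): kernel_1 is (hidden, 2, intermediate), kernel_2 is
    # (intermediate, hidden), x_contracting_dims == (2,), xt_batch_dims == (1, 2)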

    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
        *amax_list_1, *scale_list_1, *amax_list_2, *scale_list_2
    )
    amax_list_1 = maybe_fm32_to_fp32(*amax_list_1)
    scale_list_1 = maybe_fm32_to_fp32(*scale_list_1)
    amax_list_2 = maybe_fm32_to_fp32(*amax_list_2)
    scale_list_2 = maybe_fm32_to_fp32(*scale_list_2)

    fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
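    # one FP8 dtype per meta slot, ordered [input, weight, grad]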
    scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(
        amax_list_1, scale_list_1, fp8_dtype_list
    )
    amax_list_1 = FP8MetaPackage.update_amax_list(amax_list_1)
    scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(
        amax_list_2, scale_list_2, fp8_dtype_list
    )
    amax_list_2 = FP8MetaPackage.update_amax_list(amax_list_2)
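    # Delayed scaling: the new scales are derived from the amax history gathered
    # in previous steps, and the history is advanced to admit this step's amaxes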

    x_amax = amax_list_1[FP8MetaPackage.INPUT_IDX][0:1]
    x_scale = scale_list_1[FP8MetaPackage.INPUT_IDX]
    x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == "layernorm":
        ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
        )
    else:
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(
            x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon
        )
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax_list_1[FP8MetaPackage.WEIGHT_IDX][0:1]
    kernel_1_scale = scale_list_1[FP8MetaPackage.WEIGHT_IDX]
    kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]

    # Note (Ming Huang): Cast only (without fused transpose) so that XLA can
    # handle the transpose itself, avoiding an unnecessary copy that would break
    # FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = tex.cast_fp8(
        kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype
    )

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, act_len, hidden_out)
    dot_1_output = fp8_dot_impl(
        ln_out,
        casted_kernel_1,
        x_scale_inv,
        kernel_1_scale_inv,
        x.dtype,
        (x_contracting_dims, (0,)),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
    )
    if use_bias:
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
    else:
        bias_1_shape = None
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
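    # checkpoint_name tags dot_1_output so jax.checkpoint policies (e.g.
    # save_only_these_names) can choose to save or recompute it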

    activation_lu_out_amax = amax_list_2[FP8MetaPackage.INPUT_IDX][0:1]
    activation_lu_out_scale = scale_list_2[FP8MetaPackage.INPUT_IDX]
    activation_lu_out_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]

    # (batch..., act_len, hidden_out) -> (batch..., hidden_out)
    casted_activation_lu_out, updated_activation_lu_amax = tex.act_lu_fp8(
        dot_1_output,
        activation_lu_out_amax,
        activation_lu_out_scale,
        activation_lu_out_scale_inv,
        fwd_dtype,
        activation_type,
    )

    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
        casted_activation_lu_out, dot_2_input_axes
    )

    kernel_2_scale = scale_list_2[FP8MetaPackage.WEIGHT_IDX]
    kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    # Note (Ming Huang): Use a native cast so that XLA can handle the transpose
    # itself, avoiding an unnecessary copy that would break FP8 GEMM pattern
    # matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_2_output = fp8_dot_impl(
        casted_activation_lu_out,
        casted_kernel_2,
        activation_lu_out_scale_inv,
        kernel_2_scale_inv,
        x.dtype,
        (x_contracting_dims, (0,)),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
    )

    if use_bias:
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
    else:
        bias_2_shape = None

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        ln_out,
        mu,
        rsigma,
        gamma,
        dot_1_output,
        casted_activation_lu_out,
        casted_kernel_1,
        casted_kernel_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        scale_inv_list_1,
        scale_inv_list_2,
        updated_x_amax,
        updated_activation_lu_amax,
        updated_kernel_1_amax,
        updated_kernel_2_amax,
        x_contracting_dims,
        xt_batch_dims,
        bias_1_shape,
        bias_2_shape,
        maybe_fp32_to_fm32,
    )

    return dot_2_output, ctx
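
# Per jax.custom_vjp convention, the bwd rule below receives the
# non-differentiable arguments declared in nondiff_argnums (fwd_dtype through
# use_bias) first, followed by the residuals (ctx) and the incoming gradient.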


def _fused_layernorm_fp8_mlp_bwd_rule(
    fwd_dtype,  # pylint: disable=unused-argument
    bwd_dtype,
    layernorm_type,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    ffn1_ckpt_name,  # pylint: disable=unused-argument
    ffn2_ckpt_name,  # pylint: disable=unused-argument
    activation_type,
    use_bias,
    ctx,
    grad,
):
    (
        x,
        ln_out,
        mu,
        rsigma,
        gamma,
        dot_1_output,
        casted_activation_lu_out,
        casted_kernel_1,
        casted_kernel_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        scale_inv_list_1,
        scale_inv_list_2,
        updated_x_amax,
        updated_activation_lu_amax,
        updated_kernel_1_amax,
        updated_kernel_2_amax,
        x_contracting_dims,
        xt_batch_dims,
        bias_1_shape,
        bias_2_shape,
        maybe_fp32_to_fm32,
    ) = ctx

    grad_amax = amax_list_2[FP8MetaPackage.GRAD_IDX][0:1]
    grad_scale = scale_list_2[FP8MetaPackage.GRAD_IDX]
    grad_scale_inv = scale_inv_list_2[FP8MetaPackage.GRAD_IDX]

    # The sharding of the output gradient should match that of dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
    if use_bias:
        casted_grad, casted_grad_t, dbias_2, updated_grad_amax = tex.dbias_cast_transpose(
            grad,
            grad_amax,
            grad_scale,
            grad_scale_inv,
            bwd_dtype,
            static_axis_boundary=-1,
            transpose_axis_boundary=-1,
        )
        dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
    else:
        casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
            grad,
            grad_amax,
            grad_scale,
            grad_scale_inv,
            bwd_dtype,
            static_axis_boundary=-1,
            transpose_axis_boundary=-1,
        )
        dbias_2 = None

    casted_activation_lu_out_t = tex.transpose(
        casted_activation_lu_out, static_axis_boundary=-1, transpose_axis_boundary=-1
    )

    # (hidden_in, batch...) x (hidden_out, batch...)
    gemm2_x_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
    wgrad_2 = fp8_dot_impl(
        casted_activation_lu_out_t,
        casted_grad_t,
        gemm2_x_scale_inv,
        grad_scale_inv,
        grad.dtype,
        (xt_batch_dims, xt_batch_dims),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
    )

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    dgrad_2 = fp8_dot_impl(
        casted_grad,
        casted_kernel_2,
        grad_scale_inv,
        kernel_2_scale_inv,
        grad.dtype,
        (x_contracting_dims, (1,)),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    dactivation_lu_amax = amax_list_1[FP8MetaPackage.GRAD_IDX][0:1]
    dactivation_lu_scale = scale_list_1[FP8MetaPackage.GRAD_IDX]
    dactivation_lu_scale_inv = scale_inv_list_1[FP8MetaPackage.GRAD_IDX]

    if len(activation_type) > 1:  # if gated
        if use_bias:
            dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = (
                tex.dbias_cast_transpose(
                    dactivation_lu,
                    dactivation_lu_amax,
                    dactivation_lu_scale,
                    dactivation_lu_scale_inv,
                    bwd_dtype,
                    static_axis_boundary=-1,
                    transpose_axis_boundary=-2,
                )
            )
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = (
                tex.dgated_act_lu_cast_transpose(
                    dgrad_2,
                    dot_1_output,
                    dactivation_lu_amax,
                    dactivation_lu_scale,
                    dactivation_lu_scale_inv,
                    bwd_dtype,
                    static_axis_boundary=-1,
                    activation_type=activation_type,
                )
            )
            dbias_1 = None
    else:
        if use_bias:
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = (
                tex.dact_lu_dbias_cast_transpose(
                    dgrad_2,
                    dot_1_output,
                    dactivation_lu_amax,
                    dactivation_lu_scale,
                    dactivation_lu_scale_inv,
                    bwd_dtype,
                    static_axis_boundary=-1,
                    transpose_axis_boundary=-2,
                    activation_type=activation_type,
                )
            )
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = (
                tex.cast_transpose(
                    dactivation_lu,
                    dactivation_lu_amax,
                    dactivation_lu_scale,
                    dactivation_lu_scale_inv,
                    bwd_dtype,
                    static_axis_boundary=-1,
                    transpose_axis_boundary=-2,
                )
            )
            dbias_1 = None

    ln_out_t = tex.transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden_in, batch...) x (act_len, hidden_out, batch...)
    gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
    xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
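    # casted_dactivation_lu_t has shape (act_len, hidden_out, batch...), so its
    # batch dims sit one position further right than those of ln_out_t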
    wgrad_1 = fp8_dot_impl(
        ln_out_t,
        casted_dactivation_lu_t,
        gemm1_x_scale_inv,
        dactivation_lu_scale_inv,
        grad.dtype,
        (xt_batch_dims, xt_batch_dims_2),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
    )

    x_contracting_dims = (
        (min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
        (1, 2),
    )
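    # contract over (act_len, hidden_out) on both operands:
    # casted_dactivation_lu is (batch..., act_len, hidden_out) and
    # casted_kernel_1 is (hidden_in, act_len, hidden_out), so dgrad_1
    # comes out as (batch..., hidden_in)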
    kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
    dgrad_1 = fp8_dot_impl(
        casted_dactivation_lu,
        casted_kernel_1,
        dactivation_lu_scale_inv,
        kernel_1_scale_inv,
        grad.dtype,
        x_contracting_dims,
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == "layernorm":
        dx, dgamma, dbeta = tex.layernorm_bwd(
            dgrad_1, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
        )
    else:
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        dx, dgamma = tex.rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    amax_list_1[FP8MetaPackage.INPUT_IDX] = (
        amax_list_1[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
    )
    amax_list_1[FP8MetaPackage.WEIGHT_IDX] = (
        amax_list_1[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_1_amax[0])
    )
    amax_list_1[FP8MetaPackage.GRAD_IDX] = (
        amax_list_1[FP8MetaPackage.GRAD_IDX].at[0].set(updated_dactivation_lu_amax[0])
    )
    amax_list_2[FP8MetaPackage.INPUT_IDX] = (
        amax_list_2[FP8MetaPackage.INPUT_IDX].at[0].set(updated_activation_lu_amax[0])
    )
    amax_list_2[FP8MetaPackage.WEIGHT_IDX] = (
        amax_list_2[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_2_amax)
    )
    amax_list_2[FP8MetaPackage.GRAD_IDX] = (
        amax_list_2[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
    )

    amax_list_1 = maybe_fp32_to_fm32(*amax_list_1)
    scale_list_1 = maybe_fp32_to_fm32(*scale_list_1)
    amax_list_2 = maybe_fp32_to_fm32(*amax_list_2)
    scale_list_2 = maybe_fp32_to_fm32(*scale_list_2)

    return (
        dx,
        dgamma,
        dbeta,
        wgrad_1,
        wgrad_2,
        dbias_1,
        dbias_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
    )


_fused_layernorm_fp8_mlp.defvjp(
    _fused_layernorm_fp8_mlp_fwd_rule, _fused_layernorm_fp8_mlp_bwd_rule
)
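
# With the VJP registered, the fused MLP is differentiable end to end; e.g.
# (illustrative) jax.value_and_grad of a loss built on fused_layernorm_fp8_mlp
# dispatches to _fused_layernorm_fp8_mlp_bwd_rule above.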