# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage

from .sharding import with_sharding_constraint_by_logical_axes
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit

    Applies the requested activation (gated when more than one activation is
    given) to ``x`` through the custom-VJP implementation.
    """
    is_gated = len(activation_type) > 1
    if is_gated:
        # Gated activations expect the two halves stacked on axis -2.
        assert x.shape[-2] == 2  # Linear + GeLU
    return _activation_lu(x, activation_type)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """Custom-VJP activation: the primal computation reuses the forward rule."""
    out, _ctx = _activation_lu_fwd_rule(x, activation_type)
    return out
def _activation_lu_fwd_rule(x, activation_type):
    """Forward rule: evaluate the activation and stash ``x`` for the backward pass."""
    return tex.act_lu(x, activation_type), (x,)
def _activation_lu_bwd_rule(activation_type, ctx, g):
    """Backward rule: gradient of the activation w.r.t. its input."""
    (x,) = ctx
    assert x.dtype == g.dtype
    grad_x = tex.dact_lu(g, x, activation_type)
    # Restore the original input shape in case dact_lu changed the layout.
    grad_x = jnp.reshape(grad_x, x.shape)
    return (grad_x,)
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
53

54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def fused_layernorm_fp8_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    fp8_meta_pkgs: List[FP8MetaPackage],
    layernorm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    layernorm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    use_bias: bool = True,
) -> jnp.ndarray:
    """
    Layernorm + GEMM1 + bias + activation + GEMM2 + bias

    Fused FP8 MLP block with a custom VJP. Exactly two kernels (and matching
    FP8 meta packages) are expected — one per GEMM. If ``layernorm_type`` is
    'rmsnorm', ``beta`` must be None and ``zero_centered_gamma`` must be False.
    Returns the output of the second GEMM (plus bias when ``use_bias``).
    """
    assert len(kernels) == 2
    assert len(fp8_meta_pkgs) == len(kernels)
    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]
    # Per-GEMM FP8 metadata (amax histories and scales).
    amax_list_1 = fp8_meta_pkgs[0].amax_list
    amax_list_2 = fp8_meta_pkgs[1].amax_list
    scale_list_1 = fp8_meta_pkgs[0].scale_list
    scale_list_2 = fp8_meta_pkgs[1].scale_list
    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "rmsnorm":
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"

    # NOTE: argument order here must match _fused_layernorm_fp8_mlp's
    # signature and its nondiff_argnums exactly.
    output = _fused_layernorm_fp8_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        fwd_dtype,
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias,
    )
    return output
127
# Arguments 11..22 (dtypes, layernorm config, axes, checkpoint names,
# activation and bias flags) are non-differentiable static configuration.
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    amax_list_1: List[jnp.ndarray],
    amax_list_2: List[jnp.ndarray],
    scale_list_1: List[jnp.ndarray],
    scale_list_2: List[jnp.ndarray],
    fwd_dtype: jnp.dtype,
    bwd_dtype: jnp.dtype,
    layernorm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    layernorm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    use_bias: bool,
):
    """Custom-VJP core of the fused MLP; the primal pass reuses the fwd rule."""
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        fwd_dtype,
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias,
    )
    return output
def _fused_layernorm_fp8_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    amax_list_1,
    amax_list_2,
    scale_list_1,
    scale_list_2,
    fwd_dtype,
    bwd_dtype,  # pylint: disable=unused-argument
    layernorm_type,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    use_bias,
):
    """
    Forward rule: layernorm/rmsnorm -> FP8 GEMM1 (+bias) -> activation ->
    FP8 GEMM2 (+bias). Returns the MLP output and a ctx tuple consumed by
    the backward rule (the ctx field order must match the bwd unpack order).
    """
    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
    # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == len(activation_type)
    assert len(kernel_2.shape) == 2

    # Contract over the hidden dim; all leading dims are batch dims.
    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    # Convert FP8 meta tensors (possibly float32-as-fm32) into plain fp32
    # for the duration of this rule; converted back in the bwd rule.
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
        *amax_list_1, *scale_list_1, *amax_list_2, *scale_list_2
    )
    amax_list_1 = maybe_fm32_to_fp32(*amax_list_1)
    scale_list_1 = maybe_fm32_to_fp32(*scale_list_1)
    amax_list_2 = maybe_fm32_to_fp32(*amax_list_2)
    scale_list_2 = maybe_fm32_to_fp32(*scale_list_2)

    # Order: [input, weight, grad] dtypes for the scale update.
    fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
    scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(
        amax_list_1, scale_list_1, fp8_dtype_list
    )
    amax_list_1 = FP8MetaPackage.update_amax_list(amax_list_1)
    scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(
        amax_list_2, scale_list_2, fp8_dtype_list
    )
    amax_list_2 = FP8MetaPackage.update_amax_list(amax_list_2)

    x_amax = amax_list_1[FP8MetaPackage.INPUT_IDX][0:1]
    x_scale = scale_list_1[FP8MetaPackage.INPUT_IDX]
    x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == "layernorm":
        # Fused norm + FP8 cast of the normalized output.
        ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
        )
    else:
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(
            x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon
        )
        mu = None  # rmsnorm has no mean statistic

    assert x.shape == ln_out.shape

    kernel_1_amax = amax_list_1[FP8MetaPackage.WEIGHT_IDX][0:1]
    kernel_1_scale = scale_list_1[FP8MetaPackage.WEIGHT_IDX]
    kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]

    # Note (Ming Huang): Use cast only to allow XLA to handle transpose, avoiding
    # an unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = tex.cast_fp8(
        kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype
    )

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(
        ln_out,
        casted_kernel_1,
        x_scale_inv,
        kernel_1_scale_inv,
        x.dtype,
        (x_contracting_dims, (0,)),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
    )
    if use_bias:
        # Broadcast the bias across the batch dims; keep the original shape
        # so the bwd rule can reshape dbias back.
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
    else:
        bias_1_shape = None
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    activation_lu_out_amax = amax_list_2[FP8MetaPackage.INPUT_IDX][0:1]
    activation_lu_out_scale = scale_list_2[FP8MetaPackage.INPUT_IDX]
    activation_lu_out_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]

    # (batch..., hidden_in) -> (batch..., hidden)
    # Fused activation + FP8 cast feeding GEMM2.
    casted_activation_lu_out, updated_activation_lu_amax = tex.act_lu_fp8(
        dot_1_output,
        activation_lu_out_amax,
        activation_lu_out_scale,
        activation_lu_out_scale_inv,
        fwd_dtype,
        activation_type,
    )

    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
        casted_activation_lu_out, dot_2_input_axes
    )

    kernel_2_scale = scale_list_2[FP8MetaPackage.WEIGHT_IDX]
    kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    # Note (Ming Huang): Use native cast to allow XLA to handle transpose, avoiding
    # an unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_in) x (hidden_out, hidden_in)
    dot_2_output = fp8_dot_impl(
        casted_activation_lu_out,
        casted_kernel_2,
        activation_lu_out_scale_inv,
        kernel_2_scale_inv,
        x.dtype,
        (x_contracting_dims, (0,)),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
    )

    if use_bias:
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
    else:
        bias_2_shape = None

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    # Residuals for the backward rule. Field order is load-bearing: it must
    # match the unpack in _fused_layernorm_fp8_mlp_bwd_rule exactly.
    ctx = (
        x,
        ln_out,
        mu,
        rsigma,
        gamma,
        beta,
        dot_1_output,
        casted_activation_lu_out,
        casted_kernel_1,
        casted_kernel_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        scale_inv_list_1,
        scale_inv_list_2,
        updated_x_amax,
        updated_activation_lu_amax,
        updated_kernel_1_amax,
        updated_kernel_2_amax,
        x_contracting_dims,
        xt_batch_dims,
        bias_1_shape,
        bias_2_shape,
        maybe_fp32_to_fm32,
    )

    return dot_2_output, ctx
372
def _fused_layernorm_fp8_mlp_bwd_rule(
    fwd_dtype,  # pylint: disable=unused-argument
    bwd_dtype,
    layernorm_type,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    ffn1_ckpt_name,  # pylint: disable=unused-argument
    ffn2_ckpt_name,  # pylint: disable=unused-argument
    activation_type,
    use_bias,
    ctx,
    grad,
):
    """
    Backward rule of the fused FP8 MLP.

    Computes dgrad/wgrad/dbias for both GEMMs, the activation gradient, the
    norm gradients, and writes updated amax values back into the FP8 meta
    lists. The ctx unpack order must match the fwd rule's ctx tuple, and the
    return-tuple order must match the differentiable arguments of
    _fused_layernorm_fp8_mlp.
    """
    (
        x,
        ln_out,
        mu,
        rsigma,
        gamma,
        beta,
        dot_1_output,
        casted_activation_lu_out,
        casted_kernel_1,
        casted_kernel_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
        scale_inv_list_1,
        scale_inv_list_2,
        updated_x_amax,
        updated_activation_lu_amax,
        updated_kernel_1_amax,
        updated_kernel_2_amax,
        x_contracting_dims,
        xt_batch_dims,
        bias_1_shape,
        bias_2_shape,
        maybe_fp32_to_fm32,
    ) = ctx

    grad_amax = amax_list_2[FP8MetaPackage.GRAD_IDX][0:1]
    grad_scale = scale_list_2[FP8MetaPackage.GRAD_IDX]
    grad_scale_inv = scale_inv_list_2[FP8MetaPackage.GRAD_IDX]

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
    if use_bias:
        # Fused: FP8 cast + transpose of the incoming grad, plus dbias_2.
        casted_grad, casted_grad_t, dbias_2, updated_grad_amax = tex.dbias_cast_transpose(
            grad,
            grad_amax,
            grad_scale,
            grad_scale_inv,
            bwd_dtype,
            static_axis_boundary=-1,
            transpose_axis_boundary=-1,
        )
        dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
    else:
        casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
            grad,
            grad_amax,
            grad_scale,
            grad_scale_inv,
            bwd_dtype,
            static_axis_boundary=-1,
            transpose_axis_boundary=-1,
        )
        dbias_2 = None

    casted_activation_lu_out_t = tex.transpose(
        casted_activation_lu_out, static_axis_boundary=-1, transpose_axis_boundary=-1
    )

    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
    wgrad_2 = fp8_dot_impl(
        casted_activation_lu_out_t,
        casted_grad_t,
        gemm2_x_scale_inv,
        grad_scale_inv,
        grad.dtype,
        (xt_batch_dims, xt_batch_dims),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
    )

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    dgrad_2 = fp8_dot_impl(
        casted_grad,
        casted_kernel_2,
        grad_scale_inv,
        kernel_2_scale_inv,
        grad.dtype,
        (x_contracting_dims, (1,)),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    dactivation_lu_amax = amax_list_1[FP8MetaPackage.GRAD_IDX][0:1]
    dactivation_lu_scale = scale_list_1[FP8MetaPackage.GRAD_IDX]
    dactivation_lu_scale_inv = scale_inv_list_1[FP8MetaPackage.GRAD_IDX]

    # Activation gradient + FP8 cast/transpose. Four variants depending on
    # gated-vs-plain activation and whether dbias_1 must be produced.
    if len(activation_type) > 1:  # if gated
        if use_bias:
            dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = (
                tex.dbias_cast_transpose(
                    dactivation_lu,
                    dactivation_lu_amax,
                    dactivation_lu_scale,
                    dactivation_lu_scale_inv,
                    bwd_dtype,
                    static_axis_boundary=-1,
                    transpose_axis_boundary=-2,
                )
            )
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            # Fully fused gated-activation-grad + cast + transpose.
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = (
                tex.dgated_act_lu_cast_transpose(
                    dgrad_2,
                    dot_1_output,
                    dactivation_lu_amax,
                    dactivation_lu_scale,
                    dactivation_lu_scale_inv,
                    bwd_dtype,
                    static_axis_boundary=-1,
                    activation_type=activation_type,
                )
            )
            dbias_1 = None
    else:
        if use_bias:
            # Fully fused activation-grad + dbias + cast + transpose.
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = (
                tex.dact_lu_dbias_cast_transpose(
                    dgrad_2,
                    dot_1_output,
                    dactivation_lu_amax,
                    dactivation_lu_scale,
                    dactivation_lu_scale_inv,
                    bwd_dtype,
                    static_axis_boundary=-1,
                    activation_type=activation_type,
                )
            )
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = (
                tex.cast_transpose(
                    dactivation_lu,
                    dactivation_lu_amax,
                    dactivation_lu_scale,
                    dactivation_lu_scale_inv,
                    bwd_dtype,
                    static_axis_boundary=-1,
                    transpose_axis_boundary=-2,
                )
            )
            dbias_1 = None

    ln_out_t = tex.transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
    # The transposed dactivation carries an extra leading axis; shift the
    # batch dims by one to match.
    xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
    wgrad_1 = fp8_dot_impl(
        ln_out_t,
        casted_dactivation_lu_t,
        gemm1_x_scale_inv,
        dactivation_lu_scale_inv,
        grad.dtype,
        (xt_batch_dims, xt_batch_dims_2),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
    )

    # Contract both the activation axis and the hidden axis against
    # kernel_1's (act, hidden_out) axes.
    x_contracting_dims = (
        (min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
        (1, 2),
    )
    kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
    dgrad_1 = fp8_dot_impl(
        casted_dactivation_lu,
        casted_kernel_1,
        dactivation_lu_scale_inv,
        kernel_1_scale_inv,
        grad.dtype,
        x_contracting_dims,
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == "layernorm":
        dx, dgamma, dbeta = tex.layernorm_bwd(
            dgrad_1,
            x,
            mu,
            rsigma,
            gamma,
            beta,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
        )
    else:
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        dx, dgamma = tex.rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    # Write the freshly observed amax values into slot 0 of each history.
    amax_list_1[FP8MetaPackage.INPUT_IDX] = (
        amax_list_1[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
    )
    amax_list_1[FP8MetaPackage.WEIGHT_IDX] = (
        amax_list_1[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_1_amax[0])
    )
    amax_list_1[FP8MetaPackage.GRAD_IDX] = (
        amax_list_1[FP8MetaPackage.GRAD_IDX].at[0].set(updated_dactivation_lu_amax[0])
    )
    amax_list_2[FP8MetaPackage.INPUT_IDX] = (
        amax_list_2[FP8MetaPackage.INPUT_IDX].at[0].set(updated_activation_lu_amax[0])
    )
    amax_list_2[FP8MetaPackage.WEIGHT_IDX] = (
        amax_list_2[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_2_amax)
    )
    amax_list_2[FP8MetaPackage.GRAD_IDX] = (
        amax_list_2[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
    )

    # Restore the fm32 meta dtype that was stripped in the fwd rule.
    amax_list_1 = maybe_fp32_to_fm32(*amax_list_1)
    scale_list_1 = maybe_fp32_to_fm32(*scale_list_1)
    amax_list_2 = maybe_fp32_to_fm32(*amax_list_2)
    scale_list_2 = maybe_fp32_to_fm32(*scale_list_2)

    # Gradients in the order of _fused_layernorm_fp8_mlp's differentiable args.
    return (
        dx,
        dgamma,
        dbeta,
        wgrad_1,
        wgrad_2,
        dbias_1,
        dbias_2,
        amax_list_1,
        amax_list_2,
        scale_list_1,
        scale_list_2,
    )


627
628
629
# Register the forward/backward rules for the fused-MLP custom VJP.
_fused_layernorm_fp8_mlp.defvjp(
    _fused_layernorm_fp8_mlp_fwd_rule, _fused_layernorm_fp8_mlp_bwd_rule
)