# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

from typing import List, Tuple
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from .cpp_extensions import cast_fp8, transpose, cast_transpose
from .cpp_extensions import gelu as te_gelu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
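
# This module wires Transformer Engine's fused LayerNorm/RMSNorm, GeLU/GeGLU and FP8 GEMM
# kernels into JAX through custom VJPs for the activation functions and fused MLP blocks
# defined below.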


def gelu(x: jnp.ndarray):
    """
    Gelu
    """
    output = _gelu(x)
    return output


@partial(jax.custom_vjp)
def _gelu(x: jnp.ndarray):
    gelu_output, _ = _gelu_fwd_rule(x)
    return gelu_output


def _gelu_fwd_rule(x):
    gelu_output = te_gelu(x)
    return gelu_output, (x,)


def _gelu_bwd_rule(ctx, g):
    x, = ctx
    assert x.dtype == g.dtype

    dx = dgelu(g, x)
    dx = jnp.reshape(dx, x.shape)
    return (dx,)


_gelu.defvjp(_gelu_fwd_rule, _gelu_bwd_rule)
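

# A minimal usage sketch (illustrative only, not part of the module; the shape and dtype
# below are arbitrary assumptions):
#
#     x = jax.random.normal(jax.random.PRNGKey(0), (8, 1024), dtype=jnp.bfloat16)
#     y = gelu(x)                                  # same shape/dtype as x
#     dx = jax.grad(lambda t: gelu(t).sum())(x)    # routed through _gelu_bwd_rule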


def geglu(x: jnp.ndarray):
    """
    Gated GeLU (GeGLU) activation with a Transformer Engine custom VJP.
    """
    assert x.shape[-2] == 2    # Linear + GeLU

    output = _geglu(x)
    return output


@partial(jax.custom_vjp)
def _geglu(x: jnp.ndarray):
    geglu_output, _ = _geglu_fwd_rule(x)
    return geglu_output


def _geglu_fwd_rule(x):
    geglu_output = gated_gelu(x)
    return geglu_output, (x,)


def _geglu_bwd_rule(ctx, g):
    x, = ctx
    assert x.dtype == g.dtype

    dx = dgated_gelu(g, x)
    dx = jnp.reshape(dx, x.shape)
    return (dx,)


_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
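

# A minimal usage sketch (illustrative only; sizes are arbitrary assumptions, but the
# size-2 axis at position -2 is required by the assert in ``geglu``):
#
#     x = jax.random.normal(jax.random.PRNGKey(0), (8, 2, 1024), dtype=jnp.bfloat16)
#     y = geglu(x)                                  # (8, 1024): the pair is gated away
#     dx = jax.grad(lambda t: geglu(t).sum())(x)    # routed through _geglu_bwd_rule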


def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            fp8_gemm_pkg: FP8MetaPackage,
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6,
                            layernorm_input_axes: Tuple[str, ...] = None,
                            dot_1_input_axes: Tuple[str, ...] = None,
                            dot_2_input_axes: Tuple[str, ...] = None,
                            ffn1_ckpt_name: str = 'ffn1',
                            ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
    """
    Layernorm + GEMM1 + GeGLU + GEMM2
    """

    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
                                      scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                      zero_centered_gamma, epsilon, layernorm_input_axes,
                                      dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
                                      ffn2_ckpt_name)
    return output
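

# A rough, non-FP8 reference of the math computed by this fused path (a sketch only:
# it ignores quantization/amax bookkeeping, shows the 'rmsnorm' case, and assumes
# gated_gelu applies GeLU to index 0 of the size-2 axis and multiplies by index 1):
#
#     ln = x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + epsilon) * gamma
#     h = jnp.einsum('...i,iak->...ak', ln, kernel_1)    # (batch..., 2, hidden_out)
#     a = jax.nn.gelu(h[..., 0, :]) * h[..., 1, :]       # GeGLU; exact GeLU variant may differ
#     out = jnp.einsum('...k,ko->...o', a, kernel_2)     # (batch..., hidden_out of GEMM2)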


@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
                             amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                             fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
                             zero_centered_gamma: bool, epsilon: float,
                             layernorm_input_axes: Tuple[str, ...],
                             dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                             ffn1_ckpt_name: str, ffn2_ckpt_name: str):
    output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
                                                  scale, scale_inv, fwd_dtype, bwd_dtype,
                                                  layernorm_type, zero_centered_gamma, epsilon,
                                                  layernorm_input_axes, dot_1_input_axes,
                                                  dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
    return output


def _layernorm_geglu_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name):

    # x should have shape (batch..., hidden)
    # kernel_1 should have shape (hidden_in, 2, hidden_out)
    # kernel_2 should have shape (hidden_in, hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == 2
    assert len(kernel_2.shape) == 2

    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    amax = FP8Helper.update_amax_history(amax)

    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    # Note (Ming Huang): Use cast only (no fused transpose) so that XLA can handle the
    # transpose itself, avoiding an extra copy that would break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, 2, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

    geglu_out_amax = amax[gemm2_x_idx, 0:1]
    geglu_out_scale = scale[gemm2_x_idx]
    geglu_out_scale_inv = scale_inv[gemm2_x_idx]

    # (batch..., 2, hidden_out) -> (batch..., hidden_out)
    casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax,
                                                          geglu_out_scale, geglu_out_scale_inv,
                                                          fwd_dtype)

    casted_geglu_out = with_sharding_constraint_by_logical_axes(casted_geglu_out, dot_2_input_axes)

    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use a native cast so that XLA can handle the transpose itself,
    # avoiding an extra copy that would break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_out) x (hidden_out, hidden_in)
    dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv,
                                kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
           updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims)

    return dot_2_output, ctx


def _layernorm_geglu_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
        ctx,
        grad):
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
    casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
    updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims = ctx

    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    # The output is sharded like dot_1's input, so constrain the incoming gradient accordingly.
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=-1)

    casted_geglu_out_t = transpose(casted_geglu_out,
                                   static_axis_boundary=-1,
                                   transpose_axis_boundary=-1)

    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
    wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
                           grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    dgeglu_amax = amax[gemm1_grad_idx, 0:1]
    dgeglu_scale = scale[gemm1_grad_idx]
    dgeglu_scale_inv = scale_inv[gemm1_grad_idx]

    casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose(
        dgrad_2,
        dot_1_output,
        dgeglu_amax,
        dgeglu_scale,
        dgeglu_scale_inv,
        bwd_dtype,
        static_axis_boundary=-1)

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (2, hidden, batch...)
    xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv,
                           grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out)
    x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple(
        i + 1 for i in x_contracting_dims)
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
    dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv,
                           grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

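    # Delayed-scaling bookkeeping: record the amaxes observed in this step so the next
    # iteration can derive fresh FP8 scaling factors from the updated history.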
    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0])
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dx, dgamma, dbeta, wgrad_1, wgrad_2, \
           fp8_max, amax, scale, scale_inv


_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
                                _layernorm_geglu_fp8_mlp_bwd_rule)


def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
                           gamma: jnp.ndarray,
                           beta: jnp.ndarray,
                           kernels: List[jnp.ndarray],
                           biases: List[jnp.ndarray],
                           fp8_gemm_pkg: FP8MetaPackage,
                           layernorm_type: str,
                           zero_centered_gamma: bool = False,
                           epsilon: float = 1e-6,
                           layernorm_input_axes: Tuple[str, ...] = None,
                           dot_1_input_axes: Tuple[str, ...] = None,
                           dot_2_input_axes: Tuple[str, ...] = None,
                           ffn1_ckpt_name: str = 'ffn1',
                           ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
    """
    Layernorm + GEMM1 + bias + GeLU + GEMM2 + bias
    """

    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    output = _layernorm_gelu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
                                     amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                     zero_centered_gamma, epsilon, layernorm_input_axes,
                                     dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
                                     ffn2_ckpt_name)
    return output
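

# A rough, non-FP8 reference of the math computed by this fused path (a sketch only:
# it ignores quantization/amax bookkeeping and shows the 'rmsnorm' case; the exact
# GeLU variant used by the fused kernels may differ):
#
#     ln = x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + epsilon) * gamma
#     h = ln @ jnp.squeeze(kernel_1, axis=-2) + bias_1    # (batch..., hidden_out)
#     a = jax.nn.gelu(h)
#     out = a @ kernel_2 + bias_2                         # (batch..., hidden_out of GEMM2)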


@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_gelu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                            kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
                            bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
                            scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
                            bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                            epsilon: float, layernorm_input_axes: Tuple[str, ...],
                            dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                            ffn1_ckpt_name: str, ffn2_ckpt_name: str):
    output, _ = _layernorm_gelu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
                                                 fp8_max, amax, scale, scale_inv, fwd_dtype,
                                                 bwd_dtype, layernorm_type, zero_centered_gamma,
                                                 epsilon, layernorm_input_axes, dot_1_input_axes,
                                                 dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
    return output


def _layernorm_gelu_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name):

    # x should have shape (batch..., hidden)
    # kernel_1 should have shape (hidden_in, 1, hidden_out)
    # kernel_2 should have shape (hidden_in, hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == 1
    assert len(kernel_2.shape) == 2

    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    # Squeeze act axis
    # (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
    kernel_1 = jnp.squeeze(kernel_1, axis=-2)

    amax = FP8Helper.update_amax_history(amax)

    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    # Note (Ming Huang): Use cast only (no fused transpose) so that XLA can handle the
    # transpose itself, avoiding an extra copy that would break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
    dot_1_output += jnp.reshape(bias_1, bias_1_shape)
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

    gelu_out_amax = amax[gemm2_x_idx, 0:1]
    gelu_out_scale = scale[gemm2_x_idx]
    gelu_out_scale_inv = scale_inv[gemm2_x_idx]

    # GeLU + cast to FP8; the (batch..., hidden_out) shape is preserved
    casted_gelu_out, updated_gelu_amax = gelu_fp8(dot_1_output, gelu_out_amax, gelu_out_scale,
                                                  gelu_out_scale_inv, fwd_dtype)

    casted_gelu_out = with_sharding_constraint_by_logical_axes(casted_gelu_out, dot_2_input_axes)

    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use a native cast so that XLA can handle the transpose itself,
    # avoiding an extra copy that would break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_out) x (hidden_out, hidden_in)
    dot_2_output = fp8_dot_impl(casted_gelu_out, casted_kernel_2, gelu_out_scale_inv,
                                kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
    dot_2_output += jnp.reshape(bias_2, bias_2_shape)
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax,
           updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims,
           bias_1.shape, bias_2.shape)

    return dot_2_output, ctx


def _layernorm_gelu_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
        ctx,
        grad):
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, \
    casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
    updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape = ctx

    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    # The output is sharded like dot_1's input, so constrain the incoming gradient accordingly.
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=-1)

    casted_gelu_out_t = transpose(casted_gelu_out,
                                  static_axis_boundary=-1,
                                  transpose_axis_boundary=-1)

    # dbias_2: reduce the incoming gradient over all leading (batch) dimensions.
    dbias_2 = jnp.sum(grad, axis=tuple(range(grad.ndim - 1)))
    dbias_2 = jnp.reshape(dbias_2, bias_2_shape)

    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
    wgrad_2 = fp8_dot_impl(casted_gelu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
                           grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    dgelu_amax = amax[gemm1_grad_idx, 0:1]
    dgelu_scale = scale[gemm1_grad_idx]
    dgelu_scale_inv = scale_inv[gemm1_grad_idx]

    casted_dgelu, casted_dgelu_t, dbias_1, updated_dgelu_amax = dgelu_dbias_cast_transpose(
        dgrad_2,
        dot_1_output,
        dgelu_amax,
        dgelu_scale,
        dgelu_scale_inv,
        bwd_dtype,
        static_axis_boundary=-1,
        transpose_axis_boundary=-1)

    dbias_1 = jnp.reshape(dbias_1, bias_1_shape)

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgelu_t, gemm1_x_scale_inv, dgelu_scale_inv, grad.dtype,
                           (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
    # Expand the act axis so wgrad_1 matches the shape of the given kernel_1
    wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
    dgrad_1 = fp8_dot_impl(casted_dgelu, casted_kernel_1, dgelu_scale_inv, kernel_1_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dgelu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_gelu_amax[0])
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
           fp8_max, amax, scale, scale_inv


_layernorm_gelu_fp8_mlp.defvjp(_layernorm_gelu_fp8_mlp_fwd_rule, _layernorm_gelu_fp8_mlp_bwd_rule)