# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
from .cpp_extensions import gelu, gelu_fp8
from .cpp_extensions import dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import silu, silu_fp8
from .cpp_extensions import dsilu, dsilu_dbias_cast_transpose
from .cpp_extensions import gated_silu, gated_silu_fp8
from .cpp_extensions import dgated_silu, dgated_silu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes


activation_dict = {
    ('gelu',): {'fwd': gelu, 'bwd': dgelu},
    ('gelu', 'linear'): {'fwd': gated_gelu, 'bwd': dgated_gelu},
    ('silu',): {'fwd': silu, 'bwd': dsilu},
    ('silu', 'linear'): {'fwd': gated_silu, 'bwd': dgated_silu},
}

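# FP8 variants of activation_dict: each 'fwd' kernel also returns an updated amax, and
# each 'bwd' kernel fuses the FP8 cast and transpose (plus dbias on non-gated paths).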
activation_fp8_dict = {
    ('gelu',): {'fwd': gelu_fp8, 'bwd': dgelu_dbias_cast_transpose},
    ('gelu', 'linear'): {'fwd': gated_gelu_fp8, 'bwd': dgated_gelu_cast_transpose},
    ('silu',): {'fwd': silu_fp8, 'bwd': dsilu_dbias_cast_transpose},
    ('silu', 'linear'): {'fwd': gated_silu_fp8, 'bwd': dgated_silu_cast_transpose},
}


def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """
    Activation unit: applies the (optionally gated) activation selected by activation_type.
    """
    if len(activation_type) > 1:
        assert x.shape[-2] == 2    # gated: activation and linear branches stacked on axis -2
    output = _activation_lu(x, activation_type)
    return output
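# A minimal usage sketch (shapes implied by the assert above):
#   activation_lu(x, ('gelu',))             # x: (batch..., hidden)
#   activation_lu(x, ('silu', 'linear'))    # x: (batch..., 2, hidden), gated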

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    _output, _ = _activation_lu_fwd_rule(x, activation_type)
    return _output


def _activation_lu_fwd_rule(x, activation_type):
    fwd_output = activation_dict[activation_type]["fwd"](x)
    return fwd_output, (x,)


def _activation_lu_bwd_rule(activation_type, ctx, g):
    x, = ctx
    assert x.dtype == g.dtype
    dx = activation_dict[activation_type]["bwd"](g, x)
    dx = jnp.reshape(dx, x.shape)
    return (dx,)

_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
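# NOTE: the custom VJP routes the backward pass through the fused dgelu/dsilu kernels
# from cpp_extensions instead of letting JAX differentiate the forward kernel.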
def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                      scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                      activation_type: Sequence[Union[str, Callable]]):
    """
    Activation unit with FP8 quantization. The dx_trans/dbias placeholders are dummy
    differentiable inputs whose "gradients" carry the transposed dgrad and the dbias
    out of the custom VJP.
    """
    transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1)
    dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype)
    dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)

    output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax,
                                scale, scale_inv, fwd_dtype, bwd_dtype, activation_type)
    return output
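# A minimal usage sketch (FP8 metadata layout assumed from the call signature):
#   y = activation_lu_fp8(x, amax, scale, scale_inv,
#                         FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE, ('gelu',))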

@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _activation_lu_fp8(x: jnp.ndarray,
                       dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray,
                       amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                       fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                       activation_type: Sequence[Union[str, Callable]]):

    output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax,
                                         scale, scale_inv, fwd_dtype, bwd_dtype,
                                         activation_type)

    return output

def _activation_lu_fp8_fwd_rule(x,
                                dx_trans_no_use,    # pylint: disable=unused-argument
                                dbias_no_use,   # pylint: disable=unused-argument
                                amax,
                                scale, scale_inv,
                                fwd_dtype, bwd_dtype,   # pylint: disable=unused-argument
                                activation_type):
    activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](
        x, amax, scale, scale_inv, fwd_dtype)

    activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
    ctx = (x, amax, scale, scale_inv)
    return activation_lu_out, ctx

def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype,   # pylint: disable=unused-argument
                                activation_type, ctx, g):
    x, amax, scale, scale_inv = ctx

    activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"]
    if len(activation_type) > 1:    # gated path produces no dbias
        dactivation_lu, dactivation_lu_trans, amax_out = \
        activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1)
        dbias = jnp.empty(x.shape[-1], x.dtype)
    else:
        dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
        activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1)
    dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
    dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
    grads = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv)
    return grads

_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule)
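# NOTE: amax/scale/scale_inv are declared differentiable only so the backward pass can
# hand the refreshed amax (amax_out) back to the caller alongside the real gradients.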


def fused_layernorm_fp8_mlp(x: jnp.ndarray,
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            biases: List[jnp.ndarray],
                            fp8_gemm_pkg: FP8MetaPackage,
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6,
                            layernorm_input_axes: Tuple[str, ...] = None,
                            dot_1_input_axes: Tuple[str, ...] = None,
                            dot_2_input_axes: Tuple[str, ...] = None,
                            ffn1_ckpt_name: str = 'ffn1',
                            ffn2_ckpt_name: str = 'ffn2',
                            activation_type: Sequence[Union[str, Callable]] = ('gelu',),
                            use_bias: bool = True) -> jnp.ndarray:
    """
    Layernorm + GEMM1 + bias + activation + GEMM2 + bias
    """

    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
                                      amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                      zero_centered_gamma, epsilon, layernorm_input_axes,
                                      dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
                                      ffn2_ckpt_name, activation_type, use_bias)
    return output
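# Expected shapes (a sketch inferred from the asserts in the forward rule):
#   x:          (batch..., hidden_in)
#   kernels[0]: (hidden_in, len(activation_type), hidden)
#   kernels[1]: (hidden, hidden_out)
#   biases are broadcast against the matching GEMM outputs.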


@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
                             bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
                             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
                             bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                             epsilon: float, layernorm_input_axes: Tuple[str, ...],
                             dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                             ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                             activation_type: Sequence[Union[str, Callable]],
                             use_bias: bool):
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1,
                                                  bias_2, fp8_max, amax, scale, scale_inv,
                                                  fwd_dtype, bwd_dtype, layernorm_type,
                                                  zero_centered_gamma, epsilon,
                                                  layernorm_input_axes, dot_1_input_axes,
                                                  dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
                                                  activation_type, use_bias)
    return output


def _fused_layernorm_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias):
    is_gated = len(activation_type) > 1

    # x should be in shape of (batch..., hidden_in)
    # kernel_1 should be in shape of (hidden_in, len(activation_type), hidden)
    # kernel_2 should be in shape of (hidden, hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == len(activation_type)
    assert len(kernel_2.shape) == 2

    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    # Squeeze act axis: (hidden_in, 1, hidden) -> (hidden_in, hidden)
    if not is_gated:
        kernel_1 = jnp.squeeze(kernel_1, axis=-2)

    amax = FP8Helper.update_amax_history(amax)
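    # Delayed scaling: the rolling amax history advances here; the fresh per-tensor amax
    # values measured below are written into slot 0 at the end of the backward pass.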

    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    # Note (Ming Huang): Use cast only and let XLA handle the transpose, to avoid an
    # unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    if use_bias:
        bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
        dot_1_output += jnp.reshape(bias_1, bias_1_shape)
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

    activation_lu_out_amax = amax[gemm2_x_idx, 0:1]
    activation_lu_out_scale = scale[gemm2_x_idx]
    activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]

    activation_lu_fwd_fp8 = activation_fp8_dict[activation_type]["fwd"]

    # gated: (batch..., 2, hidden) -> (batch..., hidden); non-gated shape is unchanged
    casted_activation_lu_out, updated_activation_lu_amax = \
        activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
                              activation_lu_out_scale_inv, fwd_dtype)

    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out,
                                                                        dot_2_input_axes)

    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use native cast and let XLA handle the transpose, to avoid an
    # unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden) x (hidden, hidden_out)
    dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
                                activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    if use_bias:
        bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
        dot_2_output += jnp.reshape(bias_2, bias_2_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

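    # ctx packs everything the backward pass reuses: normalization stats (mu, rsigma),
    # the FP8-casted activations and kernels, and the running FP8 metadata.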
    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
           x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape)

    return dot_2_output, ctx


def _fused_layernorm_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
        activation_type,
        use_bias,
        ctx,
        grad):
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
    casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
    updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape = ctx

    is_gated = len(activation_type) > 1

    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    # The output grad shares the sharding of dot_1's input, so constrain it the same way
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    if use_bias:
        casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
        dbias_cast_transpose(grad, grad_amax, grad_scale,
                             grad_scale_inv, bwd_dtype,
                             static_axis_boundary=-1,
                             transpose_axis_boundary=-1)
        dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
    else:
        casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale,
                       grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1,
                       transpose_axis_boundary=-1)
        dbias_2 = jnp.empty(bias_2_shape, grad.dtype)
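    # Both branches emit grad in FP8 twice: casted_grad feeds the dgrad GEMM and its
    # pre-transposed copy casted_grad_t feeds the wgrad GEMM below.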

    casted_activation_lu_out_t = transpose(casted_activation_lu_out,
                                           static_axis_boundary=-1,
                                           transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden_out, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
    wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
                           grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    dactivation_lu_amax = amax[gemm1_grad_idx, 0:1]
    dactivation_lu_scale = scale[gemm1_grad_idx]
    dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]

    dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"]
    dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output)

    if is_gated:
        if use_bias:
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
            dbias_cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            dactivation_lu_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1)
            dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
    else:
        if use_bias:
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
            dactivation_lu_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-1)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-1)
            dbias_1 = jnp.empty(bias_1_shape, bwd_dtype)
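    # NOTE: two of the branches above call the fused activation-bwd kernels on dgrad_2 and
    # the saved dot_1_output; the other two fall back to casting the precomputed
    # dactivation_lu. Branches without a real dbias return an empty placeholder.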

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
    xt_batch_dims_2 = xt_batch_dims if not is_gated \
        else tuple(i + 1 for i in xt_batch_dims)
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
                           dactivation_lu_scale_inv, grad.dtype,
                           (xt_batch_dims, xt_batch_dims_2),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
    # Expand act axis to match the shape of the given kernel_1
    if not is_gated:
        wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)

    # (batch..., (2,) hidden) x (hidden_in, (2,) hidden); the act axis exists only when gated
    if is_gated:
        x_contracting_dims = ((min(x_contracting_dims),) + tuple(
            i + 1 for i in x_contracting_dims), (1, 2))
    else:
        x_contracting_dims = (x_contracting_dims, (1,))
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1,
                           dactivation_lu_scale_inv, kernel_1_scale_inv,
                           grad.dtype, x_contracting_dims,
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0])
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
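    # Per the delayed-scaling recipe, next step's scales come from the amax history; the
    # FP8 meta tensors are returned as "gradients" so the training loop can carry them.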
    return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
           fp8_max, amax, scale, scale_inv


_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
                                _fused_layernorm_fp8_mlp_bwd_rule)