# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
    get_quantize_config,
)


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    batch_sequence_transpose: bool = False,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    collective_op_sets: Tuple[tex.CollectiveOpSet] = (
        tex.noop_collective_op_set,
        tex.noop_collective_op_set,
    ),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, len(activation_type), intermediate]
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
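
    Example:
        A minimal single-device sketch. Shapes, dtypes, and the import path are
        illustrative; sharding axes, collective ops, and quantizer sets are left
        at their no-op defaults, and biases are omitted (passing None disables
        them):

            import jax.numpy as jnp
            from transformer_engine.jax.layernorm_mlp import layernorm_mlp

            batch, seqlen, hidden, intermediate = 2, 128, 768, 3072
            x = jnp.ones((batch, seqlen, hidden), dtype=jnp.bfloat16)
            gamma = jnp.ones((hidden,), dtype=jnp.bfloat16)
            beta = jnp.zeros((hidden,), dtype=jnp.bfloat16)
            # kernel_1 carries an activation axis:
            # (hidden_in, len(activation_type), intermediate)
            kernel_1 = jnp.zeros((hidden, 1, intermediate), dtype=jnp.bfloat16)
            kernel_2 = jnp.zeros((intermediate, hidden), dtype=jnp.bfloat16)

            out = layernorm_mlp(
                x,
                gamma,
                beta,
                kernels=[kernel_1, kernel_2],
                biases=[None, None],
                norm_type="layernorm",
                activation_type=("gelu",),
            )
            assert out.shape == (batch, seqlen, hidden)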
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    if not get_quantize_config().is_fp8_enabled():
        input_dtype = x.dtype
        kernel_1 = kernel_1.astype(input_dtype)
        kernel_2 = kernel_2.astype(input_dtype)

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        batch_sequence_transpose,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        collective_op_sets,
        quantizer_sets,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    batch_sequence_transpose: bool,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    collective_op_sets: Tuple[tex.CollectiveOpSet],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for first weight matrix sharding
        kernel_2_axes: Logical axes for second weight matrix sharding
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        batch_sequence_transpose,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        collective_op_sets,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    batch_sequence_transpose,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    collective_op_sets,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_1_axes, kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
    collective_op_set_1, collective_op_set_2 = collective_op_sets

    assert not collective_op_set_1.forward.is_reduce_scatter
    assert not collective_op_set_2.forward.is_all_gather
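    # FFN1's forward GEMM may only all-gather its (possibly sequence-sharded)
    # input and FFN2's forward GEMM may only reduce-scatter its output; the
    # asserts above reject the opposite collectives, matching the usual
    # sequence-parallel MLP pattern.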

    # x should have shape (batch..., hidden_in)
    # kernel_1 should have shape (hidden_in, activation_len, intermediate)
    # kernel_2 should have shape (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, act_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
        collective_op=collective_op_set_1.forward,
    )

    if use_bias_1 and tex.gemm_uses_jax_dot():
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    # This sharding constraint is needed to correct the Shardy sharding propagation
    if dot_2_input_axes is not None:
        dot_1_output_axes = (
            dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
        )  # add the act_num axis
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # (batch..., act_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(
        dot_1_output,
        activation_type,
        quantizer=ffn2_quantizer_set.x,
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2,
        quantizer=ffn2_quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
    )

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
        collective_op=collective_op_set_2.forward,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    # sharding of outputs should be the same as dot_1's input
    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    batch_sequence_transpose,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    collective_op_sets,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
    collective_op_set_1, collective_op_set_2 = collective_op_sets

    assert not collective_op_set_1.backward.is_all_gather
    assert not collective_op_set_2.backward.is_reduce_scatter

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad,
        is_dbias=use_bias_2,
        quantizer=ffn1_quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
    )

    # grad's contracting dims: its trailing dims, matching kernel_2's non-contracted (output) dims
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # kernel_2 dims that were not contracted in the forward GEMM
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    # (batch..., hidden_out) x (hidden_in, hidden_out)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
        transpose_batch_sequence=batch_sequence_transpose,
        collective_op=collective_op_set_2.backward,
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )
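    # These are the leading batch/sequence dims shared by the forward
    # activations and the incoming grads; contracting over them in the wgrad
    # GEMMs below yields kernel-shaped gradients.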

    # TN GEMM
    # (hidden, batch...,) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
    )

    # dact's contracting dims: its trailing dims, matching kernel_1's non-contracted (output) dims
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # kernel_1 dims that were not contracted in the forward GEMM
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
        transpose_batch_sequence=batch_sequence_transpose,
        collective_op=collective_op_set_1.backward,
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)