# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
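
Example (a minimal pure-JAX sketch of the same computation; a reference for the
math only, not the optimized/quantized path this module implements):

    import jax
    import jax.numpy as jnp

    def reference_layernorm_mlp(x, gamma, beta, w1, b1, w2, b2, eps=1e-6):
        # 1. Layer normalization
        mu = x.mean(-1, keepdims=True)
        var = x.var(-1, keepdims=True)
        y = (x - mu) / jnp.sqrt(var + eps) * gamma + beta
        # 2. First dense layer transformation (GEMM1) with bias and activation
        h = jax.nn.gelu(y @ w1 + b1)
        # 3. Second dense layer transformation (GEMM2) with bias
        return h @ w2 + b2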
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
    get_quantize_config,
)


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    batch_sequence_transpose: bool = False,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    activation_params: dict = None,
    collective_op_sets: Tuple[tex.CollectiveOpSet] = (
        tex.noop_collective_op_set,
        tex.noop_collective_op_set,
    ),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, activation_len, intermediate]
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
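
    Example (a minimal sketch; shapes are illustrative and assume a single
    "gelu" activation, so kernel1 carries an activation axis of length 1):

        x = jnp.ones((4, 128, 512))            # (batch, seq, hidden_in)
        gamma = jnp.ones((512,))
        beta = jnp.zeros((512,))
        kernel_1 = jnp.zeros((512, 1, 2048))   # (hidden_in, act_len, intermediate)
        kernel_2 = jnp.zeros((2048, 512))      # (intermediate, hidden_in)
        bias_1 = jnp.zeros((2048,))
        bias_2 = jnp.zeros((512,))

        out = layernorm_mlp(
            x, gamma, beta,
            kernels=[kernel_1, kernel_2],
            biases=[bias_1, bias_2],
            norm_type="layernorm",
            activation_type=("gelu",),
        )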
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

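    # Without quantized (FP8) execution, run both GEMMs in the input dtype by casting the kernels to match x.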
    if not get_quantize_config().is_fp8_enabled():
        input_dtype = x.dtype
        kernel_1 = kernel_1.astype(input_dtype)
        kernel_2 = kernel_2.astype(input_dtype)

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        batch_sequence_transpose,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        activation_params,
        collective_op_sets,
        quantizer_sets,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    batch_sequence_transpose: bool,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    activation_params: dict,
    collective_op_sets: Tuple[tex.CollectiveOpSet],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for first weight matrix sharding
        kernel_2_axes: Logical axes for second weight matrix sharding
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
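
    Example (sketch; the public layernorm_mlp wrapper is the intended entry
    point, and gradients flow through this custom VJP; tensors as in the
    layernorm_mlp docstring example):

        def loss_fn(x):
            out = layernorm_mlp(
                x, gamma, beta, [kernel_1, kernel_2], [bias_1, bias_2],
                norm_type="layernorm",
            )
            return jnp.sum(out)

        dx = jax.grad(loss_fn)(x)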
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        batch_sequence_transpose,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        activation_params,
        collective_op_sets,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    batch_sequence_transpose,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    activation_params,
    collective_op_sets,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_1_axes, kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
    collective_op_set_1, collective_op_set_2 = collective_op_sets

    assert not collective_op_set_1.forward.is_reduce_scatter
    assert not collective_op_set_2.forward.is_all_gather

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
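    # e.g. x: (batch, seq, hidden_in) contracts its last dim with kernel_1's dim 0 (hidden_in)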

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
        collective_op=collective_op_set_1.forward,
    )

    if use_bias_1 and tex.gemm_uses_jax_dot():
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
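        # The reshape above maps e.g. a (2048,) bias to (1, 1, 1, 2048) for a (batch, seq, 1, 2048) output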

    # This sharding constraint is needed to correct the Shardy sharding propagation
    if dot_2_input_axes is not None:
        dot_1_output_axes = (
            dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
        )  # add the act_num axis
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # Activation reduces the act_len axis: (batch..., act_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(
        dot_1_output,
        activation_type,
        quantizer=ffn2_quantizer_set.x,
        act_params=(
            tex.activation.ActivationParams.create(activation_type, **activation_params)
            if activation_params
            else None
        ),
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2,
        quantizer=ffn2_quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
    )

    # NN GEMM
    # (batch..., hidden_out) x (hidden_out, hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
        collective_op=collective_op_set_2.forward,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    # sharding of outputs should be the same as dot_1's input
    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    batch_sequence_transpose,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    activation_params,
    collective_op_sets,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
    collective_op_set_1, collective_op_set_2 = collective_op_sets

    assert not collective_op_set_1.backward.is_all_gather
    assert not collective_op_set_2.backward.is_reduce_scatter

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad,
        is_dbias=use_bias_2,
        quantizer=ffn1_quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
    )

    # g_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_2.ndim
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )
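    # e.g. grad: (batch, seq, hidden_in) and kernel_2: (intermediate, hidden_in)
    # give g_contracting_dims_2 = (2,) and k_contracting_dims_2 = (1,)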

    # NT GEMM
    # (batch..., hidden_in) x (hidden_out, hidden_in)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
        transpose_batch_sequence=batch_sequence_transpose,
        collective_op=collective_op_set_2.backward,
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )
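    # e.g. x: (batch, seq, hidden) contracted on dim 2 in fwd -> wgrad contracts dims (0, 1), the batch/seq dims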

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
        act_params=(
            tex.activation.ActivationParams.create(activation_type, **activation_params)
            if activation_params
            else None
        ),
    )

    # g_contracting_dims calibrated with the shape difference of dact_out.ndim vs kernel_1.ndim
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
        transpose_batch_sequence=batch_sequence_transpose,
        collective_op=collective_op_set_1.backward,
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)