# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
)


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    transpose_batch_sequence: bool = False,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    activation_params: dict = None,
    collective_op_sets: Tuple[tex.CollectiveOpSet] = (
        tex.noop_collective_op_set,
        tex.noop_collective_op_set,
    ),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, activation_len, intermediate]
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        collective_op_sets: Tuple of two collective GEMM op sets for the two dense layer transformations
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
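
    Example:
        A minimal usage sketch; the shapes, dtypes, and values below are illustrative
        assumptions rather than requirements of the API:

            x = jnp.zeros((4, 128, 512), jnp.bfloat16)          # [batch, seq, hidden_in]
            gamma = jnp.ones((512,), jnp.float32)
            beta = jnp.zeros((512,), jnp.float32)
            kernel_1 = jnp.zeros((512, 1, 2048), jnp.bfloat16)  # [hidden_in, act_len, intermediate]
            kernel_2 = jnp.zeros((2048, 512), jnp.bfloat16)     # [intermediate, hidden_in]
            bias_1 = jnp.zeros((2048,), jnp.bfloat16)
            bias_2 = jnp.zeros((512,), jnp.bfloat16)
            out = layernorm_mlp(
                x, gamma, beta,
                kernels=[kernel_1, kernel_2],
                biases=[bias_1, bias_2],
                norm_type="layernorm",
                activation_type=("gelu",),
            )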
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    if quantizer_sets == (noop_quantizer_set, noop_quantizer_set):
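        # Default no-op quantizer sets: no quantization is applied, so cast the kernels
        # to the input dtype and run both GEMMs in that precision.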
        input_dtype = x.dtype
        kernel_1 = kernel_1.astype(input_dtype)
        kernel_2 = kernel_2.astype(input_dtype)

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        transpose_batch_sequence,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        activation_params,
        collective_op_sets,
        quantizer_sets,
    )
    return output


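# Arguments 7-20 (norm settings, sharding axes, checkpoint names, activation and collective
# options) are non-differentiable; quantizer_sets remains a differentiable argument and is
# returned by the backward rule so updated quantizer state can flow back to the caller.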
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    transpose_batch_sequence: bool,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    activation_params: dict,
    collective_op_sets: Tuple[tex.CollectiveOpSet],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for first weight matrix sharding
        kernel_2_axes: Logical axes for second weight matrix sharding
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        collective_op_sets: Tuple of two collective GEMM op sets for the two dense layer transformations
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        transpose_batch_sequence,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        activation_params,
        collective_op_sets,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    transpose_batch_sequence,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    activation_params,
    collective_op_sets,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_1_axes, kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
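    # Per-GEMM collectives: the first forward GEMM must not use reduce-scatter and the
    # second must not use all-gather.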
    collective_op_set_1, collective_op_set_2 = collective_op_sets

    assert not collective_op_set_1.forward.is_reduce_scatter
    assert not collective_op_set_2.forward.is_all_gather

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1,
        flatten_axis=-2,
        quantizer=ffn1_quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, act_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
        collective_op=collective_op_set_1.forward,
    )

    if use_bias_1 and tex.gemm_uses_jax_dot():
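        # The jax.dot path does not fuse the bias into the GEMM, so add it here with
        # explicit broadcasting over the leading dims.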
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    # This sharding constraint is needed to correct the Shardy sharding propagation
    if dot_2_input_axes is not None:
        dot_1_output_axes = (
            dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
        )  # add the act_num axis
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # Activation: (batch..., act_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(
        dot_1_output,
        activation_type,
        quantizer=ffn2_quantizer_set.x,
        act_params=(
            tex.activation.ActivationParams.create(activation_type, **activation_params)
            if activation_params
            else None
        ),
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2,
        quantizer=ffn2_quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
        collective_op=collective_op_set_2.forward,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    # sharding of outputs should be the same as dot_1's input
    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

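    # Residuals saved for the backward pass (consumed by _layernorm_mlp_bwd_rule).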
    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn1_quantizer_set.x),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn1_quantizer_set.kernel),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn2_quantizer_set.x),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn2_quantizer_set.kernel),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    transpose_batch_sequence,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    activation_params,
    collective_op_sets,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
    collective_op_set_1, collective_op_set_2 = collective_op_sets

    assert not collective_op_set_1.backward.is_all_gather
    assert not collective_op_set_2.backward.is_reduce_scatter

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad,
        is_dbias=use_bias_2,
        quantizer=ffn1_quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # Contract grad's trailing dims with kernel_2's non-contracting (output) dims for dgrad_2
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    # (batch..., hidden_in) x (intermediate, hidden_in) -> (batch..., intermediate)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
        transpose_batch_sequence=transpose_batch_sequence,
        collective_op=collective_op_set_2.backward,
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

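    # For the wgrad GEMMs, contract over all leading (batch...) dims of the saved
    # activations and the incoming grads.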
    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )

    # TN GEMM
    # (hidden, batch...,) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
        act_params=(
            tex.activation.ActivationParams.create(activation_type, **activation_params)
            if activation_params
            else None
        ),
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # Contract dact's trailing dims with kernel_1's non-contracting (output) dims for dgrad_1
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
        transpose_batch_sequence=transpose_batch_sequence,
        collective_op=collective_op_set_1.backward,
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

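    # Backprop through the normalization to obtain dx and the gamma/beta gradients.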
    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

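    # Gradients are returned in the order of the differentiable arguments:
    # (x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, quantizer_sets).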
    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)