# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""
15

16
from typing import List, Tuple, Sequence, Union, Callable
17
from functools import partial
18
19
20

import jax
import jax.numpy as jnp
21
from jax.ad_checkpoint import checkpoint_name
22

23
from . import cpp_extensions as tex
24
from .cpp_extensions.amax import AmaxScope
25
from .layernorm import canonicalize_norm_type
26
27
28
29
30
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
31
    get_quantize_config,
32
)
Alp Dener's avatar
Alp Dener committed
33
34


35
def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    transpose_batch_sequence: bool = False,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    activation_params: dict = None,
    collective_op_sets: Tuple[tex.CollectiveOpSet] = (
        tex.noop_collective_op_set,
        tex.noop_collective_op_set,
    ),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    When FP8 quantization is not enabled in the global quantize config, both
    kernels are cast to ``x.dtype`` before the computation.

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of exactly two weight matrices:
            - kernel1: [hidden_in, len(activation_type), intermediate]
              (3-D; the middle axis holds one slice per activation)
            - kernel2: [intermediate, hidden_in]
        biases: List of exactly two bias terms, in the order (bias1, bias2),
            added to the outputs of the first and second dense layers
            respectively. Entries may be None when no bias is used.
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        transpose_batch_sequence: Batch/sequence transposition flag, forwarded
            unchanged to the underlying normalization, quantization, activation
            and GEMM calls
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the input of the first
            matrix multiplication (also applied to the final output)
        dot_2_input_axes: Logical axes for sharding the input of the second
            matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
            (applied to its gradient in the backward pass)
        kernel_2_axes: Logical axes for sharding the second weight matrix
            (applied to its gradient in the backward pass)
        ffn1_ckpt_name: Checkpoint name attached to the first dense layer output
        ffn2_ckpt_name: Checkpoint name attached to the second dense layer output
        activation_type: Activation function(s) to apply after the first dense
            layer transformation; its length must match kernel1's middle axis
        activation_params: Optional dict of keyword arguments used to build the
            activation parameters; None or an empty dict means no parameters
        collective_op_sets: Tuple of two collective gemm config sets for the two
            dense layer transformations
        quantizer_sets: Tuple of two quantizer sets for the two dense layer
            transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Raises:
        AssertionError: If ``kernels`` or ``biases`` does not contain exactly two
            entries, or if ``norm_type`` is "rmsnorm" while ``beta`` is not None
            or ``zero_centered_gamma`` is True.

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - The two quantizer sets are used for the first and second dense layer
          transformations respectively
        - Checkpoint names are attached to the outputs of both dense layers
    """
    assert len(kernels) == 2, f"layernorm_mlp expects exactly 2 kernels, got {len(kernels)}"
    # Validate biases up front: a short list would otherwise surface as a bare
    # IndexError, and a longer one would be silently truncated.
    assert len(biases) == 2, f"layernorm_mlp expects exactly 2 biases, got {len(biases)}"

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    # Without FP8, run both GEMMs in the activation dtype by casting the kernels
    # to match the input.
    if not get_quantize_config().is_fp8_enabled():
        input_dtype = x.dtype
        kernel_1 = kernel_1.astype(input_dtype)
        kernel_2 = kernel_2.astype(input_dtype)

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        transpose_batch_sequence,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        activation_params,
        collective_op_sets,
        quantizer_sets,
    )
    return output


149
@partial(jax.custom_vjp, nondiff_argnums=tuple(range(7, 21)))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    transpose_batch_sequence: bool,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    activation_params: dict,
    collective_op_sets: Tuple[tex.CollectiveOpSet],
    quantizer_sets,
):
    """Primal function of layernorm_mlp, registered with a custom VJP.

    The body simply runs the forward rule and discards the residuals it saves
    for the backward pass. Positional arguments 7 through 20 (``norm_type`` up
    to and including ``collective_op_sets``) are declared non-differentiable;
    the tensors before them and the trailing ``quantizer_sets`` are treated as
    differentiable inputs.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        transpose_batch_sequence: Batch/sequence transposition flag forwarded to
            the underlying primitives
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for first weight matrix sharding
        kernel_2_axes: Logical axes for second weight matrix sharding
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        activation_params: Optional dict of activation parameters
        collective_op_sets: Tuple of two collective gemm config sets for the two
            dense layer transformations
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    # Keyword arguments make the one-to-one forwarding explicit; only the
    # primal output is needed here, the residuals are dropped.
    primal_out, _residuals = _layernorm_mlp_fwd_rule(
        x=x,
        gamma=gamma,
        beta=beta,
        kernel_1=kernel_1,
        kernel_2=kernel_2,
        bias_1=bias_1,
        bias_2=bias_2,
        norm_type=norm_type,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        transpose_batch_sequence=transpose_batch_sequence,
        norm_input_axes=norm_input_axes,
        dot_1_input_axes=dot_1_input_axes,
        dot_2_input_axes=dot_2_input_axes,
        kernel_1_axes=kernel_1_axes,
        kernel_2_axes=kernel_2_axes,
        ffn1_ckpt_name=ffn1_ckpt_name,
        ffn2_ckpt_name=ffn2_ckpt_name,
        activation_type=activation_type,
        activation_params=activation_params,
        collective_op_sets=collective_op_sets,
        quantizer_sets=quantizer_sets,
    )
    return primal_out


233
def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    transpose_batch_sequence,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    activation_params,
    collective_op_sets,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions (each bias is applied only if it is not None)
    6. Sharding constraints
    7. Checkpoint naming of both dense layer outputs

    The arguments mirror those of ``_layernorm_mlp``. ``kernel_1_axes`` and
    ``kernel_2_axes`` are not used in the forward pass.

    Returns:
        Tuple of (output, context) for automatic differentiation, where context
        holds the tensors and metadata consumed by ``_layernorm_mlp_bwd_rule``.

    Raises:
        AssertionError: If the first collective op is a reduce-scatter, the
            second is an all-gather, or the kernel ranks/shapes do not match
            the expectations documented inline below.
    """
    del kernel_1_axes, kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
    collective_op_set_1, collective_op_set_2 = collective_op_sets

    assert not collective_op_set_1.forward.is_reduce_scatter
    assert not collective_op_set_2.forward.is_all_gather

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    # Both GEMMs contract the last axis of the LHS with the first axis of the kernel.
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]

    # Each bias is tracked independently; deriving use_bias_2 from bias_1 would
    # either drop a lone bias_2 or try to apply a missing one.
    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    # With the jax-dot GEMM path the bias is not passed to the GEMM call; it is
    # added explicitly afterwards instead.
    use_jax_dot = tex.gemm_uses_jax_dot()

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1,
        flatten_axis=-2,
        quantizer=ffn1_quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, activation_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
        bias=None if use_jax_dot else bias_1,
        fuse_bias=False if use_jax_dot else use_bias_1,
        collective_op=collective_op_set_1.forward,
    )

    if use_bias_1 and use_jax_dot:
        # Left-pad the bias shape with ones so it broadcasts over the leading dims.
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    # This sharding constraint is needed to correct the Shardy sharding propagation
    if dot_2_input_axes is not None:
        dot_1_output_axes = (
            dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
        )  # add the act_num axis
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # Apply the activation(s) to the first dense output and quantize the result
    # as the input of the second GEMM.
    casted_act_out = tex.act_lu(
        dot_1_output,
        activation_type,
        quantizer=ffn2_quantizer_set.x,
        act_params=(
            tex.activation.ActivationParams.create(activation_type, **activation_params)
            if activation_params
            else None
        ),
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2,
        quantizer=ffn2_quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in)
    # NOTE(review): reuses x_contracting_dims, i.e. assumes the activation
    # output has the same rank as x -- confirm against tex.act_lu.
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
        bias=None if use_jax_dot else bias_2,
        fuse_bias=False if use_jax_dot else use_bias_2,
        collective_op=collective_op_set_2.forward,
    )

    if use_bias_2 and use_jax_dot:
        # Left-pad the bias shape with ones so it broadcasts over the leading dims.
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    # sharding of outputs should be the same as dot_1's input
    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    # Residuals for the backward rule; the order must match the unpacking in
    # _layernorm_mlp_bwd_rule.
    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


410
411
def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    transpose_batch_sequence,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    activation_params,
    collective_op_sets,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Walks the forward computation in reverse:
    1. Quantize the incoming gradient (and reduce it to dbias_2 if requested)
    2. Second dense layer: input gradient (dgrad_2) and weight gradient (wgrad_2)
    3. Activation gradient, quantized (and reduced to dbias_1 if requested)
    4. First dense layer: input gradient (dgrad_1) and weight gradient (wgrad_1)
    5. Normalization gradient: dx, dgamma, dbeta

    The leading arguments are the non-differentiable arguments of
    ``_layernorm_mlp``; ``ctx`` is the residual tuple produced by the forward
    rule and ``grad`` is the gradient with respect to the output.

    Returns:
        Tuple (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2,
        quantizer_sets), one entry per differentiable input of
        ``_layernorm_mlp``. The quantizer sets are passed through unchanged.
    """
    # Not needed in the backward pass.
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name

    # The unpacking order must mirror the ctx tuple built by the forward rule.
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        ln_out_trans,
        kernel_1_trans,
        dot_1_output,
        act_out_trans,
        kernel_2_trans,
        fwd_x_contracting_dims,
        fwd_k_contracting_dims,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
    collective_op_set_1, collective_op_set_2 = collective_op_sets

    assert not collective_op_set_1.backward.is_all_gather
    assert not collective_op_set_2.backward.is_reduce_scatter

    def _trailing_grad_dims(grad_ndim, kernel_shape):
        """Trailing gradient axes that pair with the kernel's non-contracted axes."""
        first_dim = grad_ndim - len(kernel_shape) + len(fwd_k_contracting_dims)
        return tuple(range(first_dim, grad_ndim))

    def _kernel_non_contracting_dims(kernel_shape):
        """Kernel axes that were not contracted in the forward GEMM."""
        return tuple(
            axis for axis in range(len(kernel_shape)) if axis not in fwd_k_contracting_dims
        )

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    # NOTE(review): the output gradient is quantized with ffn1's dgrad quantizer
    # and the activation gradient below with ffn2's; kept exactly as found --
    # confirm this pairing is intentional.
    casted_grad, dbias_2 = tex.quantize_dbias(
        grad,
        is_dbias=use_bias_2,
        quantizer=ffn1_quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    g_contracting_dims_2 = _trailing_grad_dims(grad.ndim, kernel_2_shape)
    k_contracting_dims_2 = _kernel_non_contracting_dims(kernel_2_shape)

    # NT GEMM
    # (batch..., hidden_out) x (hidden_in, hidden_out)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        kernel_2_trans,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
        transpose_batch_sequence=transpose_batch_sequence,
        collective_op=collective_op_set_2.backward,
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    # Weight gradients contract every leading axis of x that was not contracted
    # in the forward GEMM; the same tuple is used for both operands.
    batch_dims = tuple(range(len(x.shape) - len(fwd_x_contracting_dims)))

    # TN GEMM
    # (hidden, batch...,) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        act_out_trans,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(batch_dims, batch_dims),
        transpose_batch_sequence=transpose_batch_sequence,
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    if activation_params:
        act_params = tex.activation.ActivationParams.create(activation_type, **activation_params)
    else:
        act_params = None

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
        act_params=act_params,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = _trailing_grad_dims(dact_out_ndim, kernel_1_shape)
    k_contracting_dims_1 = _kernel_non_contracting_dims(kernel_1_shape)

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        kernel_1_trans,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
        transpose_batch_sequence=transpose_batch_sequence,
        collective_op=collective_op_set_1.backward,
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        ln_out_trans,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(batch_dims, batch_dims),
        transpose_batch_sequence=transpose_batch_sequence,
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)

575

576
# Register the forward and backward rules above as the custom VJP of
# `_layernorm_mlp`, so differentiation goes through them.
_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)