# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
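
A simplified, unfused reference of the computation (an illustrative sketch only: it
assumes a single GELU activation and a 2D kernel_1, and omits the quantization,
sharding, and checkpointing handled by the optimized implementation in this module)::

    import jax
    import jax.numpy as jnp

    def reference_layernorm_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, eps=1e-6):
        # 1. Layer normalization over the hidden dimension
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        ln_out = (x - mean) / jnp.sqrt(var + eps) * gamma + beta
        # 2. First dense layer transformation with bias and activation
        act_out = jax.nn.gelu(ln_out @ kernel_1 + bias_1)
        # 3. Second dense layer transformation with bias
        return act_out @ kernel_2 + bias_2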
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial
import warnings

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
)


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: ln_out = (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = ln_out * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, activation_len, intermediate], where
              activation_len = len(activation_type)
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
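
    Example:
        A minimal usage sketch (the shapes and batch dimensions below are illustrative
        assumptions, not requirements of this API; biases are omitted by passing None)::

            import jax
            import jax.numpy as jnp

            hidden_in, intermediate = 1024, 4096
            x = jax.random.normal(jax.random.PRNGKey(0), (8, 128, hidden_in))
            gamma = jnp.ones((hidden_in,))
            beta = jnp.zeros((hidden_in,))
            # kernel_1 carries one slice along its second axis per entry in activation_type
            kernel_1 = jnp.zeros((hidden_in, 1, intermediate))
            kernel_2 = jnp.zeros((intermediate, hidden_in))

            out = layernorm_mlp(
                x,
                gamma,
                beta,
                kernels=[kernel_1, kernel_2],
                biases=[None, None],
                norm_type="layernorm",
                activation_type=("gelu",),
            )  # out.shape == (8, 128, hidden_in)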
    """
    assert len(kernels) == 2

    # For MaxText TP (= Megatron TP + sharding in the hidden dimension of the remaining
    # unsharded activations), JAX dot_general may perform better than the TE GEMM custom call.
    # This inspection only works if either norm_input_axes or dot_1_input_axes is set.
    is_mxfp8 = (
        False
        if quantizer_sets[0] == noop_quantizer_set
        else quantizer_sets[0].x.scaling_mode.is_1d_block_scaling()
    )
    inspect_axes = norm_input_axes or dot_1_input_axes
    if (
        inspect_axes is not None
        and len(inspect_axes) == x.ndim
        and inspect_axes[-1] is not None
        and not is_mxfp8
    ):
        warnings.warn(
            "Detected sharding in the hidden dimension of the MLP activation input. For improved"
            " performance, consider using JAX’s built-in `dot_general` implementation.  To try"
            " this, set the environment variable: `NVTE_JAX_CUSTOM_CALLS='GemmPrimitive=false'`",
            UserWarning,
        )

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_1_axes, kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # x should have shape (batch..., hidden_in)
    # kernel_1 should have shape (hidden_in, activation_len, intermediate)
    # kernel_2 should have shape (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        noop_scaled_tensor=True,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, activation_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias_1 and tex.gemm_uses_jax_dot():
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # (batch..., activation_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(
        dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
    )
    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
    casted_kernel_2 = tex.quantize(
        kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
    # The sharding of the output grad should match the sharding of dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
    casted_grad, dbias_2 = tex.quantize_dbias(
        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
    )

    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_2.ndim
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    # (batch..., hidden_in) x (intermediate, hidden_in)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
        noop_scaled_tensor=True,
    )

    # k_non_contracting_dims calibrated with the shape difference of dact_out.ndim vs kernel_1.ndim
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)

_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)