# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""

import warnings
from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
)
from .sharding import (
    get_sequence_parallel_dim,
)


LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False


def _issue_batch_first_warning(msg):
    global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
    if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
        warnings.warn(msg, UserWarning)
        LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    batch_first: bool = True,
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: norm_out = (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = norm_out * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight tensors:
            - kernel1: [hidden_in, act_len, intermediate], where act_len == len(activation_type)
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        batch_first: Assume that x is batched in the first dimension if it has more than 2 dims.
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
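
    Example:
        A minimal usage sketch; the shapes, dtypes, and values below are
        illustrative assumptions (the default no-op quantizer sets are used):

            hidden_in, intermediate = 1024, 4096
            x = jnp.ones((8, 128, hidden_in), dtype=jnp.bfloat16)
            gamma = jnp.ones((hidden_in,))
            beta = jnp.zeros((hidden_in,))
            # kernel_1 carries an activation axis: [hidden_in, act_len, intermediate]
            kernel_1 = jnp.ones((hidden_in, 1, intermediate), dtype=jnp.bfloat16)
            kernel_2 = jnp.ones((intermediate, hidden_in), dtype=jnp.bfloat16)
            bias_1 = jnp.zeros((intermediate,))
            bias_2 = jnp.zeros((hidden_in,))
            out = layernorm_mlp(
                x, gamma, beta,
                kernels=[kernel_1, kernel_2],
                biases=[bias_1, bias_2],
                norm_type="layernorm",
                activation_type=("gelu",),
            )
            # out.shape == (8, 128, hidden_in)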
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        batch_first,
        quantizer_sets,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    batch_first: bool,
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        batch_first: Assume that X is batched in the first dimension.
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        batch_first,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    batch_first,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_1_axes, kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
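    # activation_len must equal len(activation_type); e.g. a gated activation such as
    # ("gelu", "linear") implies activation_len == 2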
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]

    x_bdim = None
    if x.ndim > 2:
        if not batch_first:
            _issue_batch_first_warning(
                "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
                "support sequence-first inputs and may produce incorrect results when "
                "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
                "inputs at your own discretion."
            )
        x_bdim = 0 if batch_first else x.ndim - 2
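    # x_bdim marks x's batch dimension for tex.gemm's batched_dims below; it stays
    # None for 2-D inputs, which have no batch dimension.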

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        noop_scaled_tensor=True,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, act_len, intermediate)
    #   -> (batch..., act_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        batched_dims=((x_bdim,), ()),
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias_1 and tex.gemm_uses_jax_dot():
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # (batch..., act_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(
        dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in) -> (batch..., hidden_in)
    sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,))
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        batched_dims=((x_bdim,), ()),
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
        sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
        sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
        x_bdim,
        sequence_dim,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    batch_first,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
        x_bdim,
        sequence_dim,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # The incoming grad has the same shape and sharding as dot_1's input,
    # so reuse dot_1_input_axes
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
    )

    # grad's contracting dims for dgrad_2: the trailing dims of grad that line up with
    # kernel_2's non-contracting dims (offset by the grad/kernel_2 rank difference)
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
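    # e.g. a 3-D grad of (batch, seq, hidden_in) with a 2-D kernel_2 and one forward
    # contracting dim gives g_contracting_dims_2 == (2,)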
    # k_non_contracting_dims
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    # (batch..., hidden_in) x (intermediate, hidden_in) -> (batch..., intermediate)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
        batched_dims=((x_bdim,), ()),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )
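    # e.g. for a 3-D x of (batch, seq, hidden_in) this is (0, 1): wgrad contracts over
    # all leading batch/sequence dims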

    # TN GEMM
    # (batch..., intermediate) x (batch..., hidden_in) -> (intermediate, hidden_in)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        batched_dims=((x_bdim,), (x_bdim,)),
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
        noop_scaled_tensor=True,
    )

    # dact grad's contracting dims for dgrad_1: the trailing dims that line up with
    # kernel_1's non-contracting dims
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
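    # e.g. a 4-D dact grad of (batch, seq, act_len, intermediate) with a 3-D kernel_1
    # and one forward contracting dim gives g_contracting_dims_1 == (2, 3)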
    # k_non_contracting_dims
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
        batched_dims=((x_bdim,), ()),
        sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
        sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (batch..., hidden_in) x (batch..., act_len, intermediate)
    #   -> (hidden_in, act_len, intermediate)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        batched_dims=((x_bdim,), (x_bdim,)),
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)