# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""

import warnings
from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
)
from .sharding import (
    get_non_contracting_logical_axes,
    get_sequence_parallel_dim,
)

LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False


def _issue_batch_first_warning(msg):
    global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
    if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
        warnings.warn(msg, UserWarning)
        LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True
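

# The fused implementation below is the supported path. As a point of reference, the
# math it computes corresponds roughly to the unfused sketch here, assuming
# norm_type="layernorm", a single non-gated "gelu" activation, and no quantization or
# sharding. The function name and shapes are illustrative only; it is not used
# anywhere in this module.
def _unfused_layernorm_mlp_reference(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, epsilon=1e-6):
    # Layer normalization over the last (hidden) dimension
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    ln_out = (x - mu) * jax.lax.rsqrt(var + epsilon) * gamma + beta
    # GEMM1: (batch..., hidden_in) x (hidden_in, 1, intermediate) -> (batch..., 1, intermediate)
    h = jnp.einsum("...h,hai->...ai", ln_out, kernel_1)
    if bias_1 is not None:
        h = h + bias_1
    # Activation collapses the activation dim: (batch..., 1, intermediate) -> (batch..., intermediate)
    h = jax.nn.gelu(jnp.squeeze(h, axis=-2))
    # GEMM2: (batch..., intermediate) x (intermediate, hidden_in) -> (batch..., hidden_in)
    out = jnp.einsum("...i,ih->...h", h, kernel_2)
    if bias_2 is not None:
        out = out + bias_2
    return out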


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    batch_first: bool = True,
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, intermediate]
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
92
93
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
94
95
96
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
Alp Dener's avatar
Alp Dener committed
97
        batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
98
99
100
101
102
103
104
105
106
107
108
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
109
110
111
112
113
114
115
116
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        batch_first,
        quantizer_sets,
    )

    return output
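

# Minimal usage sketch for layernorm_mlp. The shapes, sizes, and the single "gelu"
# activation below are illustrative assumptions; quantization and sharding axes are
# left at their defaults, and whether this runs depends on the local build of the
# extensions. The helper is not called anywhere in this module.
def _layernorm_mlp_usage_sketch():
    rng = jax.random.PRNGKey(0)
    batch, seqlen, hidden_in, intermediate = 4, 16, 64, 256
    x = jax.random.normal(rng, (batch, seqlen, hidden_in), dtype=jnp.float32)
    gamma = jnp.ones((hidden_in,), dtype=jnp.float32)
    beta = jnp.zeros((hidden_in,), dtype=jnp.float32)
    # kernel_1 carries an activation dimension: (hidden_in, len(activation_type), intermediate)
    kernel_1 = jnp.zeros((hidden_in, 1, intermediate), dtype=jnp.float32)
    kernel_2 = jnp.zeros((intermediate, hidden_in), dtype=jnp.float32)
    return layernorm_mlp(
        x,
        gamma,
        beta,
        kernels=[kernel_1, kernel_2],
        biases=[None, None],  # biases are optional; pass arrays to enable the bias additions
        norm_type="layernorm",
        activation_type=("gelu",),
    )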


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
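# Arguments 7 through 18 (norm_type through batch_first) are non-differentiable static
# configuration. quantizer_sets (argument 19) is left as a regular argument and is
# passed back out of the backward pass alongside the parameter gradients.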
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    batch_first: bool,
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        batch_first: Assume that x is batched in the first dimension.
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        batch_first,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    batch_first,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)
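    # For example, a non-gated ("gelu",) activation gives kernel_1 of shape
    # (hidden_in, 1, intermediate), while a gated pair such as ("gelu", "linear")
    # gives (hidden_in, 2, intermediate).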

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]

    x_bdim = None
    if x.ndim > 2:
        if not batch_first:
            _issue_batch_first_warning(
                "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
                "support sequence-first inputs and may produce incorrect results when "
                "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
                "inputs at your own discretion."
            )
        x_bdim = 0 if batch_first else x.ndim - 2
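    # For example, a batch-first 3D input of shape (batch, seq, hidden) gives
    # x_bdim = 0, while a sequence-first input of shape (seq, batch, hidden) gives
    # x_bdim = 1.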

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        noop_scaled_tensor=True,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, activation_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        batched_dims=((x_bdim,), ()),
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
    )

    if dot_1_input_axes is not None and kernel_1_axes is not None:
        dot_1_output_axes = (
            *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
            *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
        )
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    if use_bias_1 and tex.gemm_uses_jax_dot():
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # Activation: (batch..., activation_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(
        dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in)
    sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,))
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        batched_dims=((x_bdim,), ()),
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
        sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
        sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
        x_bdim,
        sequence_dim,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    batch_first,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
        x_bdim,
        sequence_dim,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # The sharding of the MLP output (and hence of `grad`) matches dot_1's input sharding
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
    )

    # grad's contracting dims for dgrad_2 correspond to kernel_2's non-contracting dims,
    # calibrated from the difference between grad.ndim and kernel_2's rank
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )
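    # For example, with grad of shape (batch, seq, hidden_in) and kernel_2 of shape
    # (intermediate, hidden_in), g_contracting_dims_2 == (2,) and
    # k_contracting_dims_2 == (1,), i.e. the hidden_in dimension of each.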

    # NT GEMM
    # (batch..., hidden_out) x (hidden_in, hidden_out)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
        batched_dims=((x_bdim,), ()),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )
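    # These are the leading (batch..., sequence) dimensions of the activations and of
    # grad; the wgrad GEMMs below contract over them, e.g. for x of shape
    # (batch, seq, hidden) this is (0, 1).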

    # TN GEMM
    # (hidden, batch...,) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        batched_dims=((x_bdim,), (x_bdim,)),
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
        noop_scaled_tensor=True,
    )

    # The activation-grad's contracting dims correspond to kernel_1's non-contracting dims,
    # calibrated from the difference between dact_out_ndim and kernel_1's rank
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
        batched_dims=((x_bdim,), ()),
        sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
        sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        batched_dims=((x_bdim,), (x_bdim,)),
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)