layernorm_mlp.py 17.4 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
5
6
7
8
9
10
11
12
13
14
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""
15

Alp Dener's avatar
Alp Dener committed
16
import warnings
17
from typing import List, Tuple, Sequence, Union, Callable
18
from functools import partial
19
20
21

import jax
import jax.numpy as jnp
22
from jax.ad_checkpoint import checkpoint_name
23

24
from . import cpp_extensions as tex
25
from .layernorm import canonicalize_norm_type
26
27
28
29
30
31
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
)
32
from .sharding import get_non_contracting_logical_axes
33

34

Alp Dener's avatar
Alp Dener committed
35
36
37
38
39
40
41
42
43
44
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False


def _issue_batch_first_warning(msg):
    """Emit ``msg`` as a ``UserWarning`` only on the first call.

    The module-level flag ``LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED`` remembers
    that the warning has been shown; every later call returns without warning.

    Args:
        msg: Text of the warning to emit.
    """
    global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
    if LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
        return
    warnings.warn(msg, UserWarning)
    # Only mark as issued after warn() returns, so an escalated warning
    # (e.g. warnings configured as errors) does not suppress future attempts.
    LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True


45
def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    batch_first: bool = True,
    quantizer_sets: Tuple[QuantizerSet, QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x @ kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 @ kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]; must be
            None when norm_type is "rmsnorm"
        kernels: Sequence of exactly two weight tensors:
            - kernel1: [hidden_in, len(activation_type), intermediate]
            - kernel2: [intermediate, hidden_in]
        biases: Sequence of exactly two bias terms, either of which may be None
            to skip that bias, or None to disable both:
            - bias1: added to the first GEMM output of shape
              [batch..., len(activation_type), intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer
            transformation; its length must equal kernel1.shape[-2]
        batch_first: Assume that x is batched in the first dimension if it has more
            than 2 dims. When False (and x has more than 2 dims), the batch dimension
            is taken to be x.ndim - 2 and a one-time warning is issued because
            sequence-first inputs are not officially supported by this fused layer.
        quantizer_sets: Tuple of two quantizer sets, one for each dense layer
            transformation (ffn1, ffn2)

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Raises:
        AssertionError: If ``kernels`` or ``biases`` does not contain exactly two
            entries, or the rmsnorm constraints on ``beta`` /
            ``zero_centered_gamma`` are violated.

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
    """
    assert len(kernels) == 2, f"layernorm_mlp expects exactly 2 kernels, got {len(kernels)}"
    # A missing bias list means "no biases"; each entry may also be None individually,
    # which the forward rule handles via its use_bias_* flags.
    if biases is None:
        biases = (None, None)
    assert len(biases) == 2, f"layernorm_mlp expects exactly 2 biases, got {len(biases)}"

    kernel_1, kernel_2 = kernels
    bias_1, bias_2 = biases

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        batch_first,
        quantizer_sets,
    )
    return output


Alp Dener's avatar
Alp Dener committed
146
@partial(jax.custom_vjp, nondiff_argnums=tuple(range(7, 19)))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    batch_first: bool,
    quantizer_sets,
):
    """Differentiable core of ``layernorm_mlp``, wrapped in ``jax.custom_vjp``.

    Positional arguments 7 through 18 (``norm_type`` ... ``batch_first``) are
    declared non-differentiable and are passed to the backward rule ahead of the
    residuals. The primal result is the output half of
    ``_layernorm_mlp_fwd_rule``; the residual context it also builds is dropped
    here.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term (or None)
        bias_2: Second bias term (or None)
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        batch_first: Assume that X is batched in the first dimension.
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    fwd_result = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        batch_first,
        quantizer_sets,
    )
    # fwd_result is (output, residuals); only the output is the primal value.
    return fwd_result[0]


223
def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    batch_first,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions (each bias is applied only if it is not None)
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation. The context
        order must match the unpacking in ``_layernorm_mlp_bwd_rule``.
    """
    del kernel_2_axes  # only used by the backward rule

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]

    # Batch dimension index handed to the GEMMs; stays None for inputs with <= 2 dims.
    x_bdim = None
    if x.ndim > 2:
        if not batch_first:
            _issue_batch_first_warning(
                "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
                "support sequence-first inputs and may produce incorrect results when "
                "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
                "inputs at your own discretion."
            )
        x_bdim = 0 if batch_first else x.ndim - 2

    # Each flag must be keyed on its own bias so the two biases can be enabled
    # independently (bias_2 must not be skipped or dereferenced based on bias_1).
    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        noop_scaled_tensor=True,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, activation_len, intermediate)
    #   -> (batch..., activation_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        batched_dims=((x_bdim,), ()),
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
    )

    if dot_1_input_axes is not None and kernel_1_axes is not None:
        dot_1_output_axes = (
            *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
            *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
        )
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    # On the jax-dot path the bias is not fused into the GEMM (see fuse_bias above),
    # so add it here, left-padding its shape with 1s to broadcast over leading dims.
    if use_bias_1 and tex.gemm_uses_jax_dot():
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # (batch..., activation_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(
        dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in) -> (batch..., hidden_in)
    # The activation output has the same rank as x, so x's contracting dims apply.
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        batched_dims=((x_bdim,), ()),
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    # Residuals for the backward rule. The element order must match the unpacking in
    # _layernorm_mlp_bwd_rule. The transposed-usage (*_TRANS) tensors are saved
    # because those are the operands the backward GEMMs consume.
    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
        x_bdim,
    )

    return dot_2_output, ctx


385
386
def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    batch_first,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Args:
        norm_type, zero_centered_gamma, epsilon, norm_input_axes, dot_1_input_axes,
        dot_2_input_axes, kernel_1_axes, kernel_2_axes, ffn1_ckpt_name,
        ffn2_ckpt_name, activation_type, batch_first: The non-differentiable
            arguments of ``_layernorm_mlp``, in declaration order.
        ctx: Residuals produced by ``_layernorm_mlp_fwd_rule``.
        grad: Cotangent of the layer output.

    Returns:
        Tuple ``(dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)``,
        one entry per differentiable input of ``_layernorm_mlp``; ``quantizer_sets``
        is passed back unchanged in its own slot.
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first

    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
        x_bdim,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    # NOTE(review): the output cotangent of the second GEMM is quantized with
    # ffn1_quantizer_set.dgrad, while the activation gradient below uses
    # ffn2_quantizer_set.dgrad; confirm this pairing is intentional.
    casted_grad, dbias_2 = tex.quantize_dbias(
        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
    )

    # Trailing dims of grad that pair with kernel_2's non-contracting dims, offset by
    # the rank difference between grad and kernel_2.
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # kernel_2's non-contracting dims from the forward GEMM.
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM: contract grad's hidden dim with kernel_2's hidden dim
    # (batch..., hidden_in) x (intermediate, hidden_in) -> (batch..., intermediate)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
        batched_dims=((x_bdim,), ()),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    # The weight gradients contract over every dim of x that was not contracted in
    # the forward pass, i.e. all leading (batch) dims.
    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )

    # TN GEMM: contract over the batch dims
    # (batch..., intermediate) x (batch..., hidden_in) -> (intermediate, hidden_in)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        batched_dims=((x_bdim,), (x_bdim,)),
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
        noop_scaled_tensor=True,
    )

    # Trailing dims of the activation gradient that pair with kernel_1's
    # non-contracting dims, offset by the rank difference against kernel_1.
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # kernel_1's non-contracting dims from the forward GEMM.
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    # (batch..., activation_len, intermediate) x (hidden_in, activation_len, intermediate)
    #   -> (batch..., hidden_in)
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
        batched_dims=((x_bdim,), ()),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM: contract over the batch dims
    # (batch..., hidden_in) x (batch..., activation_len, intermediate)
    #   -> (hidden_in, activation_len, intermediate)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
        batched_dims=((x_bdim,), (x_bdim,)),
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)

533

534
# Attach the forward (primal + residuals) and backward (cotangent) rules that make
# _layernorm_mlp differentiable under jax.custom_vjp.
_layernorm_mlp.defvjp(fwd=_layernorm_mlp_fwd_rule, bwd=_layernorm_mlp_bwd_rule)