# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
    get_quantize_config,
)


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, activation_len, intermediate], where
              activation_len == len(activation_type)
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
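
    Example:
        A minimal usage sketch; shapes, dtypes, and values below are illustrative
        placeholders, not requirements::

            batch, seqlen, hidden_in, intermediate = 2, 128, 768, 3072
            x = jnp.ones((batch, seqlen, hidden_in), jnp.bfloat16)
            gamma = jnp.ones((hidden_in,), jnp.float32)
            beta = jnp.zeros((hidden_in,), jnp.float32)
            kernel_1 = jnp.ones((hidden_in, 1, intermediate), jnp.bfloat16)
            kernel_2 = jnp.ones((intermediate, hidden_in), jnp.bfloat16)
            bias_1 = jnp.zeros((intermediate,), jnp.bfloat16)
            bias_2 = jnp.zeros((hidden_in,), jnp.bfloat16)
            out = layernorm_mlp(
                x,
                gamma,
                beta,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                norm_type="layernorm",
                activation_type=("gelu",),
            )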
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    if not get_quantize_config().is_fp8_enabled():
        input_dtype = x.dtype
        kernel_1 = kernel_1.astype(input_dtype)
        kernel_2 = kernel_2.astype(input_dtype)

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_1_axes, kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
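    # ffn1_quantizer_set is used for the layernorm output and kernel_1 (first GEMM);
    # ffn2_quantizer_set is used for the activation output and kernel_2 (second GEMM).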

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
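    # activation_len == len(activation_type); e.g. a gated activation specified as a
    # pair such as ("gelu", "linear") uses activation_len == 2, a plain ("gelu",) uses 1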
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias_1 and tex.gemm_uses_jax_dot():
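        # The jax.dot-based GEMM path does not fuse the bias add, so apply it here:
        # reshape bias_1 with leading singleton dims so it broadcasts over the batch dims.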
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    # This sharding constraint is needed to correct the Shardy sharding propagation
    if dot_2_input_axes is not None:
        dot_1_output_axes = (
            dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
        )  # add the act_num axis
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # Activation: (batch..., act_len, hidden_out) -> (batch..., hidden_out)
    casted_act_out = tex.act_lu(
        dot_1_output,
        activation_type,
        quantizer=ffn2_quantizer_set.x,
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2,
        quantizer=ffn2_quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
    )

    # NN GEMM
    # (batch..., hidden_out) x (hidden_out, hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name

    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # The MLP output shares the sharding of dot_1's input, so constrain the incoming
    # gradient with dot_1_input_axes
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad,
        is_dbias=use_bias_2,
        quantizer=ffn1_quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
    )

    # Contract grad over the dims that match kernel_2's forward non-contracting
    # (output) dims; the offset is the ndim difference between grad and kernel_2
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )
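    # For example, with x of rank 3 and kernel_2 of shape (hidden_out, hidden_in):
    # grad is (batch, seq, hidden_in), so g_contracting_dims_2 = (2,) and
    # k_contracting_dims_2 = (1,), i.e. both contract over hidden_in.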

    # NT GEMM
    # (batch..., hidden_in) x (hidden_out, hidden_in) -> (batch..., hidden_out)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )
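    # For wgrad, contract over all leading batch/sequence dims; e.g. for rank-3
    # activations this yields (0, 1), leaving only the feature dims in the result.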

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
    )

    # Contract dact over the dims that match kernel_1's forward non-contracting
    # (output) dims; the offset is the ndim difference between dact and kernel_1
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)