# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
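
Conceptually (ignoring quantization, sharding constraints, and activation
checkpointing), the block computes the following unfused reference, a minimal sketch
assuming the default single non-gated "gelu" activation and 2D kernels:

    ln_out = jax.nn.standardize(x, axis=-1, epsilon=epsilon) * gamma + beta
    hidden = jax.nn.gelu(ln_out @ kernel_1 + bias_1)
    output = hidden @ kernel_2 + bias_2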
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
)


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    quantizer_sets: Tuple[QuantizerSet, QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, len(activation_type), intermediate]
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
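
    Example (illustrative shapes, assuming the default single "gelu" activation so
    kernel_1 carries an activation axis of length 1):

        x = jnp.zeros((4, 128, 1024))        # [batch, seq, hidden_in]
        gamma = jnp.ones((1024,))
        beta = jnp.zeros((1024,))
        kernel_1 = jnp.zeros((1024, 1, 4096))
        kernel_2 = jnp.zeros((4096, 1024))
        bias_1 = jnp.zeros((4096,))
        bias_2 = jnp.zeros((1024,))
        out = layernorm_mlp(
            x, gamma, beta,
            kernels=[kernel_1, kernel_2],
            biases=[bias_1, bias_2],
            norm_type="layernorm",
        )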
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for first weight matrix sharding
        kernel_2_axes: Logical axes for second weight matrix sharding
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_1_axes, kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # x should have shape (batch..., hidden_in)
    # kernel_1 should have shape (hidden_in, activation_len, intermediate)
    # kernel_2 should have shape (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
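    # Illustrative shape walkthrough: with x of shape (batch, seq, hidden_in) and
    # kernel_1 of shape (hidden_in, act_len, intermediate), contracting x dim -1 with
    # kernel_1 dim 0 yields dot_1_output of shape (batch, seq, act_len, intermediate),
    # where act_len == len(activation_type).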

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
        noop_scaled_tensor=True,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(
        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, act_len, intermediate) -> (batch..., act_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias_1 and tex.gemm_uses_jax_dot():
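        # Broadcast the bias over the leading dims: e.g., a bias_1 of shape
        # (intermediate,) is viewed as (1, ..., 1, intermediate) so it adds over the
        # batch (and activation) dims of dot_1_output.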
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    # This sharding constraint is needed to correct the Shardy sharding propagation
    if dot_2_input_axes is not None:
        dot_1_output_axes = (
            dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
        )  # add the act_num axis
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # (batch..., act_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(
        dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
    )

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(
        kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
    )

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in) -> (batch..., hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias_2 and tex.gemm_uses_jax_dot():
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # The MLP output has the same sharding as dot_1's input, so constrain the incoming gradient accordingly
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
    )

    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_2.ndim
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )
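    # Illustrative example: for grad of shape (batch, seq, hidden_in) and kernel_2 of
    # shape (intermediate, hidden_in), g_contracting_dims_2 == (2,) and
    # k_contracting_dims_2 == (1,), so both operands contract over hidden_in and
    # dgrad_2 has shape (batch, seq, intermediate).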

    # NT GEMM
    # (batch..., hidden_in) x (intermediate, hidden_in) -> (batch..., intermediate)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )
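    # Illustrative example: for x of shape (batch, seq, hidden_in) this selects the
    # batch dims (0, 1), so the wgrad GEMMs below contract activations and gradients
    # over the batch dims and produce kernel-shaped gradients.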

    # TN GEMM
    # (batch..., intermediate) x (batch..., hidden_in) -> (intermediate, hidden_in)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
        noop_scaled_tensor=True,
    )

    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )
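    # Illustrative example: with casted_dact_out of shape (batch, seq, act_len, intermediate)
    # and kernel_1 of shape (hidden_in, act_len, intermediate), g_contracting_dims_1 == (2, 3)
    # and k_contracting_dims_1 == (1, 2), so dgrad_1 has shape (batch, seq, hidden_in).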

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (batch..., hidden_in) x (batch..., act_len, intermediate) -> (hidden_in, act_len, intermediate)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, g_contracting_dims),
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)