# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import (
    with_sharding_constraint_by_logical_axes,
    QuantizerSet,
    noop_quantizer_set,
    TensorUsage,
)
from .sharding import get_non_contracting_logical_axes


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, activation_len, intermediate], where
              activation_len == len(activation_type)
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms (pass None to skip a bias addition):
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
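
    Example:
        A minimal usage sketch. Shapes and dtypes below are illustrative only, and
        quantization stays disabled because the default no-op quantizer sets are used::

            x = jnp.zeros((4, 128, 1024), jnp.bfloat16)          # (batch, seq, hidden_in)
            gamma = jnp.ones((1024,), jnp.bfloat16)
            beta = jnp.zeros((1024,), jnp.bfloat16)
            kernel_1 = jnp.zeros((1024, 1, 4096), jnp.bfloat16)  # (hidden_in, act_len, intermediate)
            kernel_2 = jnp.zeros((4096, 1024), jnp.bfloat16)     # (intermediate, hidden_in)
            bias_1 = jnp.zeros((4096,), jnp.bfloat16)
            bias_2 = jnp.zeros((1024,), jnp.bfloat16)
            out = layernorm_mlp(
                x,
                gamma,
                beta,
                kernels=[kernel_1, kernel_2],
                biases=[bias_1, bias_2],
                norm_type="layernorm",
                activation_type=("gelu",),
            )
            # out.shape == (4, 128, 1024)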
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


# Args 7-17 (norm type/config flags, sharding axes, checkpoint names, activation_type) are
# non-differentiable static arguments of the custom VJP; quantizer_sets remains a regular
# argument so the backward rule can thread it through and return it.
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
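
    # Illustrative shape flow (assuming x: (B, S, H), kernel_1: (H, A, I) with
    # A == len(activation_type), kernel_2: (I, H)):
    #   norm:   (B, S, H)              -> (B, S, H)
    #   GEMM1:  (B, S, H) x (H, A, I)  -> (B, S, A, I)
    #   act_lu: (B, S, A, I)           -> (B, S, I)
    #   GEMM2:  (B, S, I) x (I, H)     -> (B, S, H)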

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, act_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1.get_tensor(TensorUsage.RHS),
        (x_contracting_dims, k_contracting_dims),
    )

    if dot_1_input_axes is not None and kernel_1_axes is not None:
        dot_1_output_axes = (
            *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
            *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
        )
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    if use_bias_1:
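        # Reshape the bias so it broadcasts over all leading (batch/sequence) dims.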
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # The activation consumes the act_len axis:
    # (batch..., act_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x)

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel)

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_tensor(TensorUsage.LHS),
        casted_kernel_2.get_tensor(TensorUsage.RHS),
        (x_contracting_dims, k_contracting_dims),
    )

    if use_bias_2:
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
        dot_1_output,
        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out,
        casted_kernel_1,
        dot_1_output,
        casted_act_out,
        casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # The incoming grad shares the sharding of dot_1's input (both have hidden_in as the last dim).
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad
    )

    # g_contracting_dims_2: trailing dims of grad that were kernel_2's non-contracting
    # (output) dims in the forward GEMM.
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_contracting_dims_2: kernel_2 dims that were non-contracting in the forward GEMM.
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    # (batch..., hidden_in) x (intermediate, hidden_in)
    dgrad_2 = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel_2,
        (g_contracting_dims_2, k_contracting_dims_2),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )
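    # These are the leading batch dims shared by the saved activations and the grads;
    # contracting over them in the TN GEMMs below yields weight-shaped gradients.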

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        casted_act_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        (x_contracting_dims, g_contracting_dims),
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
    )

    # g_contracting_dims_1: trailing dims of dact_out that were kernel_1's non-contracting
    # dims in the forward GEMM.
    dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # k_contracting_dims_1: kernel_1 dims that were non-contracting in the forward GEMM.
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_tensor(TensorUsage.LHS),
        casted_kernel_1,
        (g_contracting_dims_1, k_contracting_dims_1),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        casted_ln_out,
        casted_dact_out.get_tensor(TensorUsage.RHS),
        (x_contracting_dims, g_contracting_dims),
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)