# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
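
In pseudo-code (ignoring quantization and sharding constraints):

    out = act(norm(x) @ kernel_1 + bias_1) @ kernel_2 + bias_2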
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex

from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = x * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, len(activation_type) * intermediate]
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [len(activation_type) * intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
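
    Example:
        A minimal usage sketch; the shapes and placeholder values below are
        illustrative only, not required by this function::

            x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)   # [batch..., hidden_in]
            gamma = jnp.ones((1024,))
            beta = jnp.zeros((1024,))
            kernel_1 = jnp.zeros((1024, 4096))                  # [hidden_in, intermediate]
            kernel_2 = jnp.zeros((4096, 1024))                  # [intermediate, hidden_in]
            bias_1 = jnp.zeros((4096,))
            bias_2 = jnp.zeros((1024,))

            out = layernorm_mlp(
                x,
                gamma,
                beta,
                kernels=[kernel_1, kernel_2],
                biases=[bias_1, bias_2],
                norm_type="layernorm",
                activation_type=("gelu",),
            )
            assert out.shape == x.shape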
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (hidden_in, activation_len * intermediate)
    # Kernel_2 should be in shape of (intermediate, hidden_in)
    assert len(kernel_1.shape) == 2
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type)
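    # Note: a gated activation such as activation_type=("gelu", "linear") has
    # len(activation_type) == 2, so kernel_1's output dim must be
    # 2 * intermediate, as the assert above requires.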
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
    assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
    )

    casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)

    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, activation_len * intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_rowwise_tensor(),
        casted_kernel_1.get_colwise_tensor(),
        (x_contracting_dims, k_contracting_dims),
    )
    if use_bias_1:
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
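
    # checkpoint_name tags dot_1_output so a rematerialization policy
    # (e.g. jax.checkpoint_policies.save_only_these_names) can decide whether to
    # save or recompute it under jax.checkpoint.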
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # (batch..., activation_len * intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x)
    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
    casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel)
    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_rowwise_tensor(),
        casted_kernel_2.get_colwise_tensor(),
        (x_contracting_dims, k_contracting_dims),
    )

    if use_bias_2:
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_colwise_tensor(),
        casted_kernel_1.get_rowwise_tensor(),
        dot_1_output,
        casted_act_out.get_colwise_tensor(),
        casted_kernel_2.get_rowwise_tensor(),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    ffn1_ckpt_name,  # pylint: disable=unused-argument
    ffn2_ckpt_name,  # pylint: disable=unused-argument
    activation_type,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
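        The gradient order matches the differentiable inputs of _layernorm_mlp:
        (x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, quantizer_sets).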
    """
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        colwise_casted_ln_out,
        rowwise_casted_kernel_1,
        dot_1_output,
        colwise_casted_act_out,
        rowwise_casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # The output grad has the same sharding as dot_1's input, so apply the same constraint
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad
    )

    # grad's trailing dims that line up with kernel_2's non-contracting dims
    g_contracting_dim_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # kernel_2's non-contracting dims (in the forward pass)
    k_contracting_dim_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    # (batch..., hidden_in) x (intermediate, hidden_in)
    dgrad_2 = tex.gemm(
        casted_grad.get_rowwise_tensor(),
        rowwise_casted_kernel_2,
        (g_contracting_dim_2, k_contracting_dim_2),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    x_contracting_dim = g_contracting_dim = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )

    # TN GEMM
    # (hidden, batch...,) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        colwise_casted_act_out,
        casted_grad.get_colwise_tensor(),
        (x_contracting_dim, g_contracting_dim),
    )

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
    )

    # dgrad_2's trailing dims that line up with kernel_1's non-contracting dims
    g_contracting_dim_1 = tuple(
        range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim)
    )
    # kernel_1's non-contracting dims (in the forward pass)
    k_contracting_dim_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_rowwise_tensor(),
        rowwise_casted_kernel_1,
        (g_contracting_dim_1, k_contracting_dim_1),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        colwise_casted_ln_out,
        casted_dact_out.get_colwise_tensor(),
        (x_contracting_dim, g_contracting_dim),
    )

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)

_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)