# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.

This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias

The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
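
Schematically (ignoring quantization, sharding constraints, and activation checkpointing,
which the fused implementation below adds), one block computes:

    h = activation(layernorm(x) @ kernel_1 + bias_1)
    out = h @ kernel_2 + bias_2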
"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from . import cpp_extensions as tex

from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .sharding import get_non_contracting_logical_axes


def layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernels: List[jnp.ndarray],
    biases: List[jnp.ndarray],
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    norm_input_axes: Tuple[str, ...] = None,
    dot_1_input_axes: Tuple[str, ...] = None,
    dot_2_input_axes: Tuple[str, ...] = None,
    kernel_1_axes: Tuple[str, ...] = None,
    kernel_2_axes: Tuple[str, ...] = None,
    ffn1_ckpt_name: str = "ffn1",
    ffn2_ckpt_name: str = "ffn2",
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
    """Apply layer normalization followed by MLP block.

    This function implements the following sequence of operations:
        1. Layer normalization: ln_out = (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. First dense layer transformation: y1 = ln_out * kernel1 + bias1
        3. Activation function: y2 = activation(y1)
        4. Second dense layer transformation: y3 = y2 * kernel2 + bias2

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        kernels: List of two weight matrices:
            - kernel1: [hidden_in, activation_len, intermediate]
            - kernel2: [intermediate, hidden_in]
        biases: List of two bias terms:
            - bias1: [intermediate]
            - bias2: [hidden_in]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        norm_input_axes: Logical axes for sharding the layernorm input
        dot_1_input_axes: Logical axes for sharding the first matrix multiplication
        dot_2_input_axes: Logical axes for sharding the second matrix multiplication
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for checkpointing the first feed-forward network
        ffn2_ckpt_name: Name for checkpointing the second feed-forward network
        activation_type: Activation function(s) to apply after the first dense layer transformation
        quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

    Returns:
        Output tensor with shape [batch..., hidden_in]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both dense layer transformations
        - Checkpointing is applied to both feed-forward networks for memory efficiency
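
    Example:
        A minimal sketch with hypothetical shapes, the default ("gelu",) activation, and
        the default no-op quantizers (the values below are placeholders):

            x = jnp.zeros((4, 128, 1024), jnp.bfloat16)          # (batch..., hidden_in)
            gamma = jnp.ones((1024,), jnp.float32)
            beta = jnp.zeros((1024,), jnp.float32)
            kernel_1 = jnp.zeros((1024, 1, 4096), jnp.bfloat16)  # (hidden_in, activation_len, intermediate)
            kernel_2 = jnp.zeros((4096, 1024), jnp.bfloat16)     # (intermediate, hidden_in)
            bias_1 = jnp.zeros((4096,), jnp.bfloat16)
            bias_2 = jnp.zeros((1024,), jnp.bfloat16)

            out = layernorm_mlp(
                x, gamma, beta,
                kernels=[kernel_1, kernel_2],
                biases=[bias_1, bias_2],
                norm_type="layernorm",
            )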
    """
    assert len(kernels) == 2

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "rmsnorm":
        assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

    output = _layernorm_mlp(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )

    return output


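# Arguments 7-17 (norm options, sharding axes, checkpoint names, and the activation type)
# are compile-time configuration and are therefore marked as non-differentiable.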
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    kernel_1: jnp.ndarray,
    kernel_2: jnp.ndarray,
    bias_1: jnp.ndarray,
    bias_2: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_input_axes: Tuple[str, ...],
    dot_1_input_axes: Tuple[str, ...],
    dot_2_input_axes: Tuple[str, ...],
    kernel_1_axes: Tuple[str, ...],
    kernel_2_axes: Tuple[str, ...],
    ffn1_ckpt_name: str,
    ffn2_ckpt_name: str,
    activation_type: Sequence[Union[str, Callable]],
    quantizer_sets,
):
    """Internal implementation of layernorm_mlp with custom VJP.

    This function implements the forward pass of layernorm_mlp with support for
    automatic differentiation. It handles the normalization, dense layer transformations,
    activation, and quantization operations.

    Args:
        x: Input tensor
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        kernel_1: First weight matrix
        kernel_2: Second weight matrix
        bias_1: First bias term
        bias_2: Second bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        norm_input_axes: Logical axes for layernorm sharding
        dot_1_input_axes: Logical axes for first matrix multiplication sharding
        dot_2_input_axes: Logical axes for second matrix multiplication sharding
        kernel_1_axes: Logical axes for sharding the first weight matrix
        kernel_2_axes: Logical axes for sharding the second weight matrix
        ffn1_ckpt_name: Name for first feed-forward network checkpointing
        ffn2_ckpt_name: Name for second feed-forward network checkpointing
        activation_type: Activation function(s)
        quantizer_sets: Tuple of quantizer sets

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        norm_type,
        zero_centered_gamma,
        epsilon,
        norm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        kernel_1_axes,
        kernel_2_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        quantizer_sets,
    )
    return output


def _layernorm_mlp_fwd_rule(
    x,
    gamma,
    beta,
    kernel_1,
    kernel_2,
    bias_1,
    bias_2,
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    quantizer_sets,
):
    """Forward pass rule for layernorm_mlp.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. First matrix multiplication with quantized kernel
    3. Activation function application
    4. Second matrix multiplication with quantized kernel
    5. Optional bias additions
    6. Sharding constraints
    7. Checkpointing for memory efficiency

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    del kernel_2_axes

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # x is expected to have shape (batch..., hidden_in)
    # kernel_1 is expected to have shape (hidden_in, activation_len, intermediate)
    # kernel_2 is expected to have shape (intermediate, hidden_in)
    assert len(kernel_1.shape) == 3
    assert len(kernel_2.shape) == 2
    assert kernel_1.shape[-2] == len(activation_type)

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]

    use_bias_1 = bias_1 is not None
    use_bias_2 = bias_2 is not None

    x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=ffn1_quantizer_set.x,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

    casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, activation_len, intermediate)
    dot_1_output = tex.gemm(
        casted_ln_out.get_rowwise_tensor(),
        casted_kernel_1.get_colwise_tensor(),
        (x_contracting_dims, k_contracting_dims),
    )

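    # When logical axes are provided, constrain the GEMM output sharding using the
    # non-contracting axes of both operands.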
    if dot_1_input_axes is not None and kernel_1_axes is not None:
        dot_1_output_axes = (
            *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
            *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
        )
        dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)

    if use_bias_1:
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)

    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    # (batch..., activation_len, intermediate) -> (batch..., intermediate)
    casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x)

    casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

    casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel)

    # NN GEMM
    # (batch..., intermediate) x (intermediate, hidden_in)
    dot_2_output = tex.gemm(
        casted_act_out.get_rowwise_tensor(),
        casted_kernel_2.get_colwise_tensor(),
        (x_contracting_dims, k_contracting_dims),
    )

    if use_bias_2:
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

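    # Residuals saved for the backward rule: normalization statistics, quantized
    # activations and kernels, and the GEMM contraction metadata.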
    ctx = (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        casted_ln_out.get_colwise_tensor(),
        casted_kernel_1.get_rowwise_tensor(),
        dot_1_output,
        casted_act_out.get_colwise_tensor(),
        casted_kernel_2.get_rowwise_tensor(),
        x_contracting_dims,
        k_contracting_dims,
        kernel_1.shape,
        kernel_2.shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    )

    return dot_2_output, ctx


def _layernorm_mlp_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    norm_input_axes,
    dot_1_input_axes,
    dot_2_input_axes,
    kernel_1_axes,
    kernel_2_axes,
    ffn1_ckpt_name,
    ffn2_ckpt_name,
    activation_type,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_mlp.

    Implements the backward pass computation including:
    1. Gradient computation for second matrix multiplication
    2. Gradient computation for activation function
    3. Gradient computation for first matrix multiplication
    4. Gradient computation for layer normalization
    5. Gradient computation for bias terms
    6. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
    (
        x,
        mu,
        rsigma,
        gamma,
        beta,
        colwise_casted_ln_out,
        rowwise_casted_kernel_1,
        dot_1_output,
        colwise_casted_act_out,
        rowwise_casted_kernel_2,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        kernel_1_shape,
        kernel_2_shape,
        use_bias_1,
        use_bias_2,
        quantizer_sets,
    ) = ctx

    ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    casted_grad, dbias_2 = tex.quantize_dbias(
        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad
    )

    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_2.ndim
    g_contracting_dims_2 = tuple(
        range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_2 = tuple(
        dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    # (batch..., hidden_in) x (intermediate, hidden_in)
    dgrad_2 = tex.gemm(
        casted_grad.get_rowwise_tensor(),
        rowwise_casted_kernel_2,
        (g_contracting_dims_2, k_contracting_dims_2),
    )

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

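    # For the wgrad GEMMs below, contract over all leading (batch...) dimensions of the
    # saved activations and the incoming gradients.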
    x_contracting_dims = g_contracting_dims = tuple(
        range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
    )

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_2 = tex.gemm(
        colwise_casted_act_out,
        casted_grad.get_colwise_tensor(),
        (x_contracting_dims, g_contracting_dims),
    )
    wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

    casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
        dgrad_2,
        dot_1_output,
        activation_type=activation_type,
        is_dbias=use_bias_1,
        quantizer=ffn2_quantizer_set.dgrad,
    )

    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
    dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
    g_contracting_dims_1 = tuple(
        range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
    )
    # k_non_contracting_dims
    k_contracting_dims_1 = tuple(
        dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad_1 = tex.gemm(
        casted_dact_out.get_rowwise_tensor(),
        rowwise_casted_kernel_1,
        (g_contracting_dims_1, k_contracting_dims_1),
    )

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

    # TN GEMM
    # (hidden, batch...) x (hidden, batch...)
    wgrad_1 = tex.gemm(
        colwise_casted_ln_out,
        casted_dact_out.get_colwise_tensor(),
        (x_contracting_dims, g_contracting_dims),
    )

    wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad_1,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

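    # Gradients are returned in the order of the differentiable arguments:
    # (x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, quantizer_sets).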
    return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)


_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)