# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused Layer normalization and dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized implementations of layer normalization followed by
dense layer transformation (GEMM) operations, which are commonly used in transformer
architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""

from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope

from .quantize import (
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    TensorUsage,
)


def layernorm_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: Optional[jnp.ndarray],
    bias: Optional[jnp.ndarray] = None,
    norm_type: str = "layernorm",
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    transpose_batch_sequence: bool = False,
    layernorm_input_axes: Optional[Tuple[str, ...]] = None,
    dot_input_axes: Optional[Tuple[str, ...]] = None,
    kernel_axes: Optional[Tuple[str, ...]] = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
    """Apply layer normalization followed by dense layer transformation.

    This function implements the following sequence of operations:
        1. Layer normalization: ln_out = (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. Linear transformation: y = ln_out @ kernel + bias

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        kernel: Weight matrix with shape [hidden_in, hidden_out]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
            (None for RMSNorm, see Note)
        bias: Optional bias term for dense layer transformation with shape [hidden_out]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
        layernorm_input_axes: Logical axes for sharding the layernorm input
        dot_input_axes: Logical axes for sharding the matrix multiplication input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: Set of quantizers for different tensor types

    Returns:
        Output tensor with shape [batch..., hidden_out]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both the normalized input and kernel
    """
    if quantizer_set == noop_quantizer_set:
        # No quantizers configured: cast the kernel to the input's dtype before
        # handing both to the custom-VJP implementation.
        kernel = kernel.astype(x.dtype)

    return _layernorm_dense(
        x,
        kernel,
        gamma,
        beta,
        bias,
        norm_type,
        zero_centered_gamma,
        epsilon,
        transpose_batch_sequence,
        layernorm_input_axes,
        dot_input_axes,
        kernel_axes,
        quantizer_set,
    )


# Arguments 5..11 (norm_type through kernel_axes) are static configuration.
# quantizer_set (argument 12) is deliberately left out of nondiff_argnums: the
# backward rule returns a value in its cotangent slot.
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11))
def _layernorm_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    bias: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    transpose_batch_sequence: bool,
    layernorm_input_axes: Tuple[str, ...],
    dot_input_axes: Tuple[str, ...],
    kernel_axes: Tuple[str, ...],
    quantizer_set,
):
    """Internal implementation of layernorm_dense with custom VJP.

    This function implements the forward pass of layernorm_dense with support for
    automatic differentiation. It delegates to the forward rule and discards the
    residuals that rule returns for the backward pass.

    Args:
        x: Input tensor
        kernel: Weight matrix
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        bias: Optional bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
        layernorm_input_axes: Logical axes for layernorm sharding
        dot_input_axes: Logical axes for matrix multiplication sharding
        kernel_axes: Logical axes for weight matrix sharding
        quantizer_set: Set of quantizers

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_dense_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        bias,
        norm_type,
        zero_centered_gamma,
        epsilon,
        transpose_batch_sequence,
        layernorm_input_axes,
        dot_input_axes,
        kernel_axes,
        quantizer_set,
    )
    return output


def _layernorm_dense_fwd_rule(
    x,
    kernel,
    gamma,
    beta,
    bias,
    norm_type,
    zero_centered_gamma,
    epsilon,
    transpose_batch_sequence,
    layernorm_input_axes,
    dot_input_axes,
    kernel_axes,
    quantizer_set,
):
    """Forward pass rule for layernorm_dense.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. Matrix multiplication with quantized kernel
    3. Optional bias addition
    4. Sharding constraints

    Returns:
        Tuple of (output, context) for automatic differentiation. The context is
        the tuple of residuals unpacked, in the same order, by the backward rule.
    """
    # Contract the last dim of x with the first dim of the kernel.
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[-1] == kernel.shape[0], (
        f"x.shape[-1] ({x.shape[-1]}) must equal kernel.shape[0] ({kernel.shape[0]})"
    )

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)

    # Kernel in (hidden_in, hidden_out...)
    # Negative axis index: evaluates to -1 for a 2D kernel.
    flatten_axis = 1 - len(kernel.shape)
    casted_kernel = tex.quantize(
        kernel,
        flatten_axis=flatten_axis,
        quantizer=quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    use_bias = bias is not None
    # Query the GEMM backend once so the fused and unfused bias paths below are
    # driven by the same answer.
    uses_jax_dot = tex.gemm_uses_jax_dot()

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
        bias=None if uses_jax_dot else bias,
        fuse_bias=False if uses_jax_dot else use_bias,
    )

    if use_bias and uses_jax_dot:
        # Bias was not passed to the GEMM on this path: left-pad its shape with
        # ones so it broadcasts against the output, then add it explicitly.
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    # Residuals for the backward rule; the order must match its unpacking.
    ctx = (
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
        casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
        x.shape,
        kernel.shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims,
        k_contracting_dims,
        use_bias,
        quantizer_set,
        flatten_axis,
    )

    return output, ctx


def _layernorm_dense_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    transpose_batch_sequence,
    layernorm_input_axes,
    dot_input_axes,
    kernel_axes,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_dense.

    Implements the backward pass computation including:
    1. Gradient computation for matrix multiplication
    2. Gradient computation for layer normalization
    3. Gradient computation for bias terms
    4. Proper handling of quantization

    Args:
        norm_type: Type of normalization (non-differentiable, from the forward call)
        zero_centered_gamma: Whether zero-centered gamma was used
        epsilon: Small constant for numerical stability
        transpose_batch_sequence: Forwarded to the quantize and GEMM calls
        layernorm_input_axes: Logical axes applied to the input gradient of the GEMM
        dot_input_axes: Unused in the backward pass
        kernel_axes: Logical axes applied to the kernel gradient
        ctx: Residual tuple produced by the forward rule
        grad: Gradient with respect to the forward output

    Returns:
        Tuple (dx, wgrad, dgamma, dbeta, dbias, quantizer_set), one entry per
        non-static argument of the forward function, in positional order.
    """
    del dot_input_axes  # not needed in the backward pass
    (
        casted_ln_out,
        casted_kernel,
        x_shape,
        kernel_shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        use_bias,
        quantizer_set,
        flatten_axis,
    ) = ctx

    casted_grad, dbias = tex.quantize_dbias(
        grad,
        is_dbias=use_bias,
        flatten_axis=flatten_axis,
        quantizer=quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # Trailing dims of grad that line up with the kernel's non-contracting dims;
    # the start index is offset by the rank difference between grad and kernel.
    grad_out_dims = tuple(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # Kernel dims that were NOT contracted in the forward GEMM.
    kernel_out_dims = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel,
        contracting_dims=(grad_out_dims, kernel_out_dims),
        transpose_batch_sequence=transpose_batch_sequence,
    )

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)

    # Leading dims of x that were not contracted in the forward GEMM; the same
    # index tuple is used for both the layernorm output and grad.
    leading_dims = tuple(range(0, len(x_shape) - len(x_contracting_dims_in_fwd)))

    # TN GEMM
    wgrad = tex.gemm(
        casted_ln_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(leading_dims, leading_dims),
        transpose_batch_sequence=transpose_batch_sequence,
    )

    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    # quantizer_set is returned unchanged in the slot of the quantizer_set argument.
    return dx, wgrad, dgamma, dbeta, dbias, quantizer_set


# Register the forward and backward rules as the custom VJP of _layernorm_dense.
_layernorm_dense.defvjp(_layernorm_dense_fwd_rule, _layernorm_dense_bwd_rule)