# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused Layer normalization and dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized implementations of layer normalization followed by
dense layer transformation (GEMM) operations, which are commonly used in transformer
architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope

from .quantize import (
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    TensorUsage,
    get_quantize_config,
)


def layernorm_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    bias: jnp.ndarray = None,
    norm_type: str = "layernorm",
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    transpose_batch_sequence: bool = False,
    layernorm_input_axes: Tuple[str, ...] = None,
    dot_input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
    """Apply layer normalization followed by dense layer transformation.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. Linear transformation: y = ln_out @ kernel + bias, where ln_out is the normalized input

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        kernel: Weight matrix with shape [hidden_in, hidden_out]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        bias: Optional bias term for dense layer transformation with shape [hidden_out]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
        layernorm_input_axes: Logical axes for sharding the layernorm input
        dot_input_axes: Logical axes for sharding the matrix multiplication input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: Set of quantizers for different tensor types

    Returns:
        Output tensor with shape [batch..., hidden_out]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both the normalized input and kernel
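
    Example:
        A minimal, illustrative sketch (shapes are arbitrary; FP8 quantization and
        sharding are left at their defaults)::

            import jax.numpy as jnp

            x = jnp.ones((2, 128, 512), dtype=jnp.bfloat16)      # [batch, seq, hidden_in]
            kernel = jnp.ones((512, 1024), dtype=jnp.bfloat16)   # [hidden_in, hidden_out]
            gamma = jnp.ones((512,), dtype=jnp.bfloat16)
            beta = jnp.zeros((512,), dtype=jnp.bfloat16)

            out = layernorm_dense(x, kernel, gamma, beta)         # [batch, seq, hidden_out]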
    """

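    # Without FP8 quantization, keep both GEMM operands in a common dtype by casting
    # the kernel to the input's dtype.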
    if not get_quantize_config().is_fp8_enabled():
        input_dtype = x.dtype
        kernel = kernel.astype(input_dtype)

    output = _layernorm_dense(
        x,
        kernel,
        gamma,
        beta,
        bias,
        norm_type,
        zero_centered_gamma,
        epsilon,
        transpose_batch_sequence,
        layernorm_input_axes,
        dot_input_axes,
        kernel_axes,
        quantizer_set,
    )
    return output


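# norm_type through kernel_axes (argument indices 5-11) are static configuration and are
# marked non-differentiable; the tensor inputs and the quantizer_set are the arguments
# traced through the custom VJP.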
@partial(
    jax.custom_vjp,
    nondiff_argnums=(
        5,
        6,
        7,
        8,
        9,
        10,
        11,
    ),
)
def _layernorm_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    bias: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    transpose_batch_sequence: bool,
    layernorm_input_axes: Tuple[str, ...],
    dot_input_axes: Tuple[str, ...],
    kernel_axes: Tuple[str, ...],
    quantizer_set,
):
    """Internal implementation of layernorm_dense with custom VJP.

    This function implements the forward pass of layernorm_dense with support for
    automatic differentiation. It handles the normalization and dense layer transformation
    operations, including quantization and sharding constraints.

    Args:
        x: Input tensor
        kernel: Weight matrix
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        bias: Optional bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
        layernorm_input_axes: Logical axes for layernorm sharding
        dot_input_axes: Logical axes for matrix multiplication sharding
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: Set of quantizers

    Returns:
        Output tensor from the combined operations
    """
    output, _ = _layernorm_dense_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        bias,
        norm_type,
        zero_centered_gamma,
        epsilon,
        transpose_batch_sequence,
        layernorm_input_axes,
        dot_input_axes,
        kernel_axes,
        quantizer_set,
    )
    return output


def _layernorm_dense_fwd_rule(
    x,
    kernel,
    gamma,
    beta,
    bias,
    norm_type,
    zero_centered_gamma,
    epsilon,
    transpose_batch_sequence,
    layernorm_input_axes,
    dot_input_axes,
    kernel_axes,
    quantizer_set,
):
    """Forward pass rule for layernorm_dense.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. Matrix multiplication with quantized kernel
    3. Optional bias addition
    4. Sharding constraints

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[-1] == kernel.shape[0]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer=quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)

    # Kernel in (hidden_in, hidden_out...)
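    # flatten_axis = 1 - ndim always resolves to axis 1 (the first hidden_out dim) as a
    # negative index, marking where the kernel's output dims begin for quantization.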
    flatten_axis = 1 - len(kernel.shape)
    casted_kernel = tex.quantize(
        kernel,
        flatten_axis=flatten_axis,
        quantizer=quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    use_bias = bias is not None
    output = tex.gemm(
        casted_ln_out.get_tensor(TensorUsage.LHS),
        casted_kernel.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
        bias=bias if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
    )

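    # When the GEMM backend falls back to a plain JAX dot (tex.gemm_uses_jax_dot()),
    # bias fusion is unavailable, so add the bias explicitly and broadcast it over the
    # leading output dims.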
    if use_bias and tex.gemm_uses_jax_dot():
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

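    # Residuals saved for the backward pass: the quantized ln_out/kernel in their
    # transposed usages, the original shapes and inputs, and the norm statistics.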
    ctx = (
        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
        casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
        x.shape,
        kernel.shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims,
        k_contracting_dims,
        use_bias,
        quantizer_set,
        flatten_axis,
    )

    return output, ctx


def _layernorm_dense_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    transpose_batch_sequence,
    layernorm_input_axes,
    dot_input_axes,
    kernel_axes,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_dense.

    Implements the backward pass computation including:
    1. Gradient computation for matrix multiplication
    2. Gradient computation for layer normalization
    3. Gradient computation for bias terms
    4. Proper handling of quantization

    Returns:
        Tuple of (dx, dkernel, dgamma, dbeta, dbias, quantizer_set) matching the
        differentiable inputs
    """
    del dot_input_axes
    (
        casted_ln_out,
        casted_kernel,
        x_shape,
        kernel_shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        use_bias,
        quantizer_set,
        flatten_axis,
    ) = ctx

    casted_grad, dbias = tex.quantize_dbias(
        grad,
        is_dbias=use_bias,
        flatten_axis=flatten_axis,
        quantizer=quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # The trailing dims of grad line up with the kernel's non-contracting (hidden_out)
    # dims from the forward pass
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # The kernel dims that were non-contracting in the forward pass
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd
    )

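    # dgrad: contract grad with the kernel over the hidden_out dims, giving a tensor
    # shaped like the normalized input (batch..., hidden_in).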
    # NT GEMM
    dgrad = tex.gemm(
        casted_grad.get_tensor(TensorUsage.LHS),
        casted_kernel,
        contracting_dims=(g_contracting_dim, k_contracting_dim),
        transpose_batch_sequence=transpose_batch_sequence,
    )

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)

    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(x_contracting_dims_in_fwd))
    )

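    # wgrad: contract ln_out with grad over the leading batch/sequence dims, giving a
    # tensor shaped like the kernel (hidden_in, hidden_out...).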
    # TN GEMM
    wgrad = tex.gemm(
        casted_ln_out,
        casted_grad.get_tensor(TensorUsage.RHS),
        contracting_dims=(x_contracting_dim, g_contracting_dim),
        transpose_batch_sequence=transpose_batch_sequence,
    )

    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

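    # Backprop through the normalization itself using the saved statistics (mu, rsigma).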
    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return dx, wgrad, dgamma, dbeta, dbias, quantizer_set


_layernorm_dense.defvjp(_layernorm_dense_fwd_rule, _layernorm_dense_bwd_rule)