# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused Layer normalization and dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized implementations of layer normalization followed by
dense layer transformation (GEMM) operations, which are commonly used in transformer
architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex

from .quantize import (
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
)


def layernorm_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    bias: jnp.ndarray = None,
    norm_type: str = "layernorm",
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    layernorm_input_axes: Tuple[str, ...] = None,
    dot_input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
    """Apply layer normalization followed by dense layer transformation.

    This function implements the following sequence of operations:
        1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
        2. Linear transformation: y = x * kernel + bias

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        kernel: Weight matrix with shape [hidden_in, hidden_out]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        bias: Optional bias term for dense layer transformation with shape [hidden_out]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        layernorm_input_axes: Logical axes for sharding the layernorm input
        dot_input_axes: Logical axes for sharding the matrix multiplication input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: Set of quantizers for different tensor types

    Returns:
        Output tensor with shape [batch..., hidden_out]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
          must be False
        - The function supports automatic differentiation through JAX's custom VJP
        - Quantization is applied to both the normalized input and kernel
    """
    # Arguments are forwarded positionally. Their order must match the signature of
    # ``_layernorm_dense``, whose indices 5-10 are declared in ``nondiff_argnums``.
    output = _layernorm_dense(
        x,
        kernel,
        gamma,
        beta,
        bias,
        norm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes,
        kernel_axes,
        quantizer_set,
    )
    return output


@partial(
    jax.custom_vjp,
    # Positional indices of the arguments that are NOT differentiated:
    # 5 norm_type, 6 zero_centered_gamma, 7 epsilon, 8 layernorm_input_axes,
    # 9 dot_input_axes, 10 kernel_axes. They are passed first to the bwd rule.
    nondiff_argnums=(
        5,
        6,
        7,
        8,
        9,
        10,
    ),
)
def _layernorm_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    bias: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    layernorm_input_axes: Tuple[str, ...],
    dot_input_axes: Tuple[str, ...],
    kernel_axes: Tuple[str, ...],
    quantizer_set,
):
    """Internal implementation of layernorm_dense with custom VJP.

    This function implements the forward pass of layernorm_dense with support for
    automatic differentiation. It handles the normalization and dense layer transformation
    operations, including quantization and sharding constraints.

    Args:
        x: Input tensor
        kernel: Weight matrix
        gamma: Scale parameter for normalization
        beta: Bias parameter for normalization
        bias: Optional bias term
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        layernorm_input_axes: Logical axes for layernorm sharding
        dot_input_axes: Logical axes for matrix multiplication sharding
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: Set of quantizers

    Returns:
        Output tensor from the combined operations
    """
    # Reuse the forward rule and discard the residuals it returns for the backward pass.
    output, _ = _layernorm_dense_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        bias,
        norm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes,
        kernel_axes,
        quantizer_set,
    )
    return output


def _layernorm_dense_fwd_rule(
    x,
    kernel,
    gamma,
    beta,
    bias,
    norm_type,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_input_axes,
    kernel_axes,
    quantizer_set,
):
    """Forward pass rule for layernorm_dense.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. Matrix multiplication with quantized kernel
    3. Optional bias addition
    4. Sharding constraints

    Returns:
        Tuple of (output, ctx) for automatic differentiation. ``ctx`` holds the
        residuals unpacked by ``_layernorm_dense_bwd_rule``, in this order:
        column-wise casted normalized input (or None), row-wise casted kernel
        (or None), x.shape, kernel.shape, mu, rsigma, x (after its sharding
        constraint), gamma, beta, x_contracting_dims, k_contracting_dims,
        use_bias, quantizer_set, flatten_axis.
    """
    # Contract the last axis of x with the first axis of the kernel.
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[-1] == kernel.shape[0], (
        f"layernorm_dense: last dim of x ({x.shape[-1]}) must match "
        f"first dim of kernel ({kernel.shape[0]})"
    )

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer_set.x,
    )
    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)

    # Kernel in (hidden_in, hidden_out...)
    # flatten_axis = 1 - ndim splits the kernel between axis 0 and the remaining axes.
    flatten_axis = 1 - len(kernel.shape)
    casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    output = tex.gemm(
        casted_ln_out.get_rowwise_tensor(),
        casted_kernel.get_colwise_tensor(),
        (x_contracting_dims, k_contracting_dims),
    )

    use_bias = bias is not None
    if use_bias:
        # Left-pad the bias shape with 1s so it broadcasts over the batch axes.
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    # NOTE(review): when a quantizer is not 2x2x, None is stored for the
    # corresponding residual, and the backward rule passes it directly to tex.gemm.
    # Confirm that this configuration is unsupported or is handled elsewhere.
    ctx = (
        casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
        casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
        x.shape,
        kernel.shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims,
        k_contracting_dims,
        use_bias,
        quantizer_set,
        flatten_axis,
    )

    return output, ctx


def _layernorm_dense_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_input_axes,  # pylint: disable=unused-argument
    kernel_axes,
    ctx,
    grad,
):
    """Backward pass rule for layernorm_dense.

    Implements the backward pass computation including:
    1. Gradient computation for matrix multiplication
    2. Gradient computation for layer normalization
    3. Gradient computation for bias terms
    4. Proper handling of quantization

    The leading arguments are the non-differentiable arguments of
    ``_layernorm_dense`` (its ``nondiff_argnums``, in order). They are followed by
    the residuals ``ctx`` produced by ``_layernorm_dense_fwd_rule`` and the output
    cotangent ``grad``.

    Returns:
        Tuple of gradients for the differentiable inputs, in the order
        (x, kernel, gamma, beta, bias, quantizer_set).
    """
    (
        colwise_casted_ln_out,
        rowwise_casted_kernel,
        x_shape,
        kernel_shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        use_bias,
        quantizer_set,
        flatten_axis,
    ) = ctx

    # Quantize the incoming gradient. The same call also yields dbias, which is
    # returned as the bias gradient.
    casted_grad, dbias = tex.quantize_dbias(
        grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
    )

    # dgrad: contract grad's trailing (output) axes with the kernel's
    # non-contracting axes. grad carries x's batch axes in front of the kernel's
    # output axes, hence the offset by grad.ndim - kernel.ndim.
    g_contracting_dims_dgrad = tuple(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # The kernel's non-contracting dims from the forward pass.
    k_contracting_dims_dgrad = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad = tex.gemm(
        casted_grad.get_rowwise_tensor(),
        rowwise_casted_kernel,
        (g_contracting_dims_dgrad, k_contracting_dims_dgrad),
    )

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)

    # wgrad: contract over every batch (leading) axis of x. These are also the
    # leading axes of grad.
    x_contracting_dims_wgrad = g_contracting_dims_wgrad = tuple(
        range(0, len(x_shape) - len(x_contracting_dims_in_fwd))
    )

    # TN GEMM
    wgrad = tex.gemm(
        colwise_casted_ln_out,
        casted_grad.get_colwise_tensor(),
        (x_contracting_dims_wgrad, g_contracting_dims_wgrad),
    )

    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return dx, wgrad, dgamma, dbeta, dbias, quantizer_set


# Register the forward/backward rules that give _layernorm_dense its custom VJP.
_layernorm_dense.defvjp(
    fwd=_layernorm_dense_fwd_rule,
    bwd=_layernorm_dense_bwd_rule,
)