# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Layer normalization operations for Transformer Engine in JAX.

This module provides optimized layer normalization operations for transformer
architectures, including support for different normalization types and quantization.
It implements various normalization strategies like LayerNorm and RMSNorm, with
optional zero-centered gamma and epsilon parameters.
"""

from functools import partial

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .quantize import (
    ScaledTensor,
    Quantizer,
)


def canonicalize_norm_type(x):
    """Convert a normalization type string to its canonical form.

    Accepts common spelling variants such as ``"LayerNorm"``, ``"layer-norm"``
    or ``"rms_norm"`` and maps them onto the canonical spellings expected by
    the C++ extensions.

    Args:
        x: Input normalization type string.

    Returns:
        Canonicalized normalization type string: ``"layernorm"`` or
        ``"rmsnorm"``.

    Raises:
        ValueError: If ``x`` does not name a supported normalization type.
    """
    canonicalized = x.lower().strip().replace("-", "").replace("_", "")
    # Raise explicitly instead of asserting so invalid input is still
    # rejected when Python runs with -O (asserts are stripped).
    if canonicalized not in ("layernorm", "rmsnorm"):
        raise ValueError(f"Unsupported normalization type: {x}")
    return canonicalized


def layernorm(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    quantizer: Quantizer = None,
):
    """Apply layer normalization with optional quantization.

    Public entry point that forwards to the differentiable implementation
    ``_layernorm``, which carries the custom VJP rules. The input is
    normalized with the given ``gamma`` (scale) and ``beta`` (shift)
    parameters.

    Args:
        x: Input tensor to normalize.
        gamma: Scale parameter for normalization.
        beta: Shift parameter for normalization.
        norm_type: Type of normalization to apply.
        zero_centered_gamma: Whether to use zero-centered gamma.
        epsilon: Small constant for numerical stability.
        quantizer: Optional quantizer for quantizing the output.

    Returns:
        Normalized output tensor.
    """
    return _layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)


@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
    """Differentiable core of :func:`layernorm` with a custom VJP.

    ``norm_type``, ``zero_centered_gamma`` and ``epsilon`` are marked as
    non-differentiable arguments of the custom VJP.

    Args:
        x: Input tensor.
        gamma: Scale parameter.
        beta: Shift parameter.
        norm_type: Type of normalization.
        zero_centered_gamma: Whether to use zero-centered gamma.
        epsilon: Small constant for numerical stability.
        quantizer: Optional quantizer.

    Returns:
        Normalized tensor.
    """
    # Reuse the forward rule directly; the residuals it returns are only
    # needed when JAX drives the fwd/bwd pair through the custom VJP.
    out, _ = _layernorm_fwd_rule(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
    return out


def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
    """Forward pass rule for layer normalization.

    Args:
        x: Input tensor.
        gamma: Scale parameter.
        beta: Shift parameter.
        norm_type: Type of normalization.
        zero_centered_gamma: Whether to use zero-centered gamma.
        epsilon: Small constant for numerical stability.
        quantizer: Optional quantizer.

    Returns:
        Tuple of (output, context) for the backward pass; the context packs
        the residuals (x, mu, rsigma, gamma, beta, quantizer).
    """
    out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        canonicalize_norm_type(norm_type),
        quantizer,
    )
    # When the extension returns a quantized tensor, hand callers the
    # dequantized (high-precision) view.
    if isinstance(out, ScaledTensor):
        out = out.dequantize()
    ctx = (x, mu, rsigma, gamma, beta, quantizer)
    return out, ctx


def _layernorm_bwd_rule(norm_type, zero_centered_gamma, epsilon, ctx, dz):
    """Backward pass rule for layer normalization.

    Args:
        norm_type: Type of normalization.
        zero_centered_gamma: Whether to use zero-centered gamma.
        epsilon: Small constant for numerical stability.
        ctx: Context from forward pass (x, mu, rsigma, gamma, beta, quantizer).
        dz: Gradient from upstream.

    Returns:
        Tuple of gradients with respect to (x, gamma, beta, quantizer).
    """
    x, mu, rsigma, gamma, beta, quantizer = ctx
    # The forward rule canonicalizes norm_type before calling into the C++
    # extensions, but custom_vjp hands this rule the caller's original
    # (non-diff) argument. Canonicalize here as well so both passes agree
    # even when the caller used a variant spelling such as "layer-norm".
    norm_type = canonicalize_norm_type(norm_type)
    dx, dgamma, dbeta = tex.normalization_bwd(
        dz, x, mu, rsigma, gamma, beta, zero_centered_gamma, epsilon, norm_type
    )
    # quantizer is a differentiable argument of the custom VJP, so a value
    # must be returned for it; pass it through unchanged (it carries no
    # numeric cotangent).
    return dx, dgamma, dbeta, quantizer


# Register the forward/backward rules with the custom-VJP wrapper.
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)