# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Layer normalization operations for Transformer Engine in JAX.

This module provides optimized layer normalization operations for transformer
architectures, including support for different normalization types and quantization.
It implements normalization strategies such as LayerNorm and RMSNorm, with optional
zero-centered gamma and a configurable epsilon.
"""

from functools import partial

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex

from .quantize import (
    Quantizer,
)


def canonicalize_norm_type(x):
    """Convert normalization type string to canonical form.

    Args:
        x: Input normalization type string

    Returns:
        Canonicalized normalization type string
    """
    canonicalized = x.lower().strip().replace("-", "").replace("_", "")
    assert canonicalized in ("layernorm", "rmsnorm"), f"Unsupported norm_type: {x!r}"
    return canonicalized

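# Illustrative examples (not part of the module source), showing the canonical forms
# produced by canonicalize_norm_type:
#   canonicalize_norm_type("Layer-Norm")  -> "layernorm"
#   canonicalize_norm_type("RMS_norm")    -> "rmsnorm"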

def layernorm(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    quantizer: Quantizer = None,
):
    """Apply layer normalization with optional quantization.

    This function implements layer normalization with support for different
    normalization types and optional quantization. It normalizes the input
    tensor using the provided gamma and beta parameters.

    Args:
        x: Input tensor to normalize
        gamma: Scale parameter for normalization
        beta: Shift parameter for normalization
        norm_type: Type of normalization to apply ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        quantizer: Optional quantizer for quantizing the output

    Returns:
        Normalized output tensor
    """
    output = _layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
    return output

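# Usage sketch (illustrative only; the shapes, dtypes, and variable names below are
# assumptions, not part of this module):
#
#     hidden = 1024
#     x = jnp.ones((16, 128, hidden), dtype=jnp.bfloat16)
#     gamma = jnp.ones((hidden,), dtype=jnp.float32)
#     beta = jnp.zeros((hidden,), dtype=jnp.float32)
#     y = layernorm(x, gamma, beta, norm_type="layernorm")
#
# With zero_centered_gamma=True the effective scale is expected to be (1 + gamma)
# rather than gamma, which suits gamma parameters initialized to zero (behavioural
# assumption; verify against the underlying kernel documentation).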

@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
    """Internal implementation of layer normalization with custom VJP.

    This function implements the core layer normalization logic with support
    for custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        gamma: Scale parameter
        beta: Shift parameter
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        quantizer: Optional quantizer

    Returns:
        Normalized tensor
    """
    output, _ = _layernorm_fwd_rule(
        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer
    )
    return output


def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
    """Forward pass rule for layer normalization.

    Args:
        x: Input tensor
        gamma: Scale parameter
        beta: Shift parameter
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        quantizer: Optional quantizer

    Returns:
        Tuple of (output, context) for backward pass
    """

    norm_type = canonicalize_norm_type(norm_type)
    output, mu, rsigma = tex.normalization_fwd(
        x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer
    )
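    # mu and rsigma are the per-row mean and reciprocal standard deviation
    # (conventionally 1 / sqrt(var + epsilon)) computed by the fused kernel; they
    # are kept as residuals so the backward pass does not need to recompute them.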
    # This is a no-op for higher-precision tensors
    output = output.dequantize()

    return output, (x, mu, rsigma, gamma, beta, quantizer)


def _layernorm_bwd_rule(norm_type, zero_centered_gamma, epsilon, ctx, dz):
    """Backward pass rule for layer normalization.

    Args:
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        ctx: Context from forward pass
        dz: Gradient from upstream

    Returns:
        Tuple of gradients (dx, dgamma, dbeta) for the differentiable inputs,
        with the quantizer passed through unchanged
    """
    x, mu, rsigma, gamma, beta, quantizer = ctx
    # norm_type is a non-differentiated argument, so it arrives here exactly as the
    # caller supplied it; canonicalize it again to match the forward pass.
    norm_type = canonicalize_norm_type(norm_type)

    dx, dgamma, dbeta = tex.normalization_bwd(
        dz, x, mu, rsigma, gamma, beta, zero_centered_gamma, epsilon, norm_type
    )
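    # custom_vjp expects one cotangent per differentiable positional argument of
    # _layernorm (x, gamma, beta, quantizer); the quantizer carries no gradient,
    # so it is returned unchanged.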
    return dx, dgamma, dbeta, quantizer


_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
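
# Example of differentiating through the registered VJP (illustrative; assumes x,
# gamma, and beta are arrays as in the usage sketch near layernorm above):
#
#     def loss_fn(x, gamma, beta):
#         return jnp.sum(layernorm(x, gamma, beta, norm_type="layernorm"))
#
#     dx, dgamma, dbeta = jax.grad(loss_fn, argnums=(0, 1, 2))(x, gamma, beta)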