layernorm.py 10.4 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""JAX layernorm modules"""

6
from functools import partial
7
8
9
import jax
import jax.numpy as jnp

10
from .cpp_extensions import cast_fp8, cast_transpose, transpose
11
12
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
13
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
14
from .fp8 import FP8Helper, FP8MetaPackage
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


def canonicalize_layernorm_type(x):
    '''
    Normalize a layernorm-type string: lowercase, trimmed, with any
    '-' / '_' separators removed, e.g. 'Layer-Norm' -> 'layernorm'.
    '''
    normalized = x.strip().lower()
    for separator in ('-', '_'):
        normalized = normalized.replace(separator, '')
    assert normalized in ['layernorm', 'rmsnorm']
    return normalized


def layernorm(inputs: jnp.ndarray,
              gamma: jnp.ndarray,
              beta: jnp.ndarray,
              layernorm_type: str,
              zero_centered_gamma: bool = False,
              epsilon: float = 1e-6):
    """
    Public LN/RMSNorm wrapper around the custom-VJP implementation.
    Only layernorm_type in ['layernorm', 'rmsnorm'] (and spelling
    aliases thereof) is supported; beta is ignored for rmsnorm.
    """
    return _layernorm(inputs,
                      gamma,
                      beta,
                      layernorm_type=layernorm_type,
                      zero_centered_gamma=zero_centered_gamma,
                      epsilon=epsilon)


45
46
47
48
49
50
51
52
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x,
               gamma,
               beta,
               layernorm_type: str,
               zero_centered_gamma: bool = False,
               epsilon: float = 1e-6):
    # Primal of the custom VJP: delegate to the fwd rule and drop the
    # residuals (they are only needed when differentiating).
    primal_out, _ = _layernorm_fwd_rule(x, gamma, beta, layernorm_type,
                                        zero_centered_gamma, epsilon)
    return primal_out


56
57
58
59
60
61
62
def _layernorm_fwd_rule(x,
                        gamma,
                        beta,
                        layernorm_type: str,
                        zero_centered_gamma: bool = False,
                        epsilon: float = 1e-6):
    # Forward rule: run the fused normalization kernel and stash the
    # statistics (mu, rsigma) needed by the backward rule.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'layernorm':
        output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        # RMSNorm keeps no mean statistic.
        mu = None
        output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")
    return output, (x, mu, rsigma, gamma)
73
74


75
76
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
    """
    Backward rule of _layernorm: route the cotangent dz through the fused
    LN/RMSNorm backward kernel.

    Returns:
        (dx, dgamma, dbeta) — gradients for the three differentiable inputs;
        dbeta is None for rmsnorm (it takes no beta).
    """
    x, mu, rsigma, gamma = ctx
    # Fix: canonicalize here as well. The nondiff argument arrives in the
    # caller's original spelling (e.g. 'rms_norm'); only the fwd rule's local
    # copy was canonicalized, so valid aliases used to raise ValueError below
    # during backprop.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dz,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")
    return dx, dgamma, dbeta
94
95


96
# Wire the custom forward/backward rules into the custom_vjp primitive.
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
97
98


99
100
def layernorm_fp8_dot(x: jnp.ndarray,
                      kernel: jnp.ndarray,
                      gamma: jnp.ndarray,
                      beta: jnp.ndarray,
                      fp8_meta_pkg: FP8MetaPackage,
                      layernorm_type: str,
                      zero_centered_gamma: bool = False,
                      epsilon: float = 1e-6) -> jnp.ndarray:
    """
    Fused Layernorm/RMSNorm + FP8 GEMM: unpacks the FP8 scaling metadata and
    dispatches to the custom-VJP implementation.
    """
    return _layernorm_fp8_dot(x,
                              kernel,
                              gamma,
                              beta,
                              fp8_meta_pkg.fp8_max,
                              fp8_meta_pkg.amax,
                              fp8_meta_pkg.scale,
                              fp8_meta_pkg.scale_inv,
                              layernorm_type,
                              FP8Helper.FWD_DTYPE,
                              FP8Helper.BWD_DTYPE,
                              zero_centered_gamma,
                              epsilon)


121
122
123
124
125
126
127
128
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                       fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                       scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
                       bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float):
    # Primal of the custom VJP: run the forward rule and discard the
    # residuals, which are only materialized when differentiating.
    primal_out, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax,
                                                scale, scale_inv, layernorm_type,
                                                fwd_dtype, bwd_dtype,
                                                zero_centered_gamma, epsilon)
    return primal_out


132
133
def _layernorm_fp8_dot_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        fp8_max,
        amax,
        scale,
        scale_inv,
        layernorm_type,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        zero_centered_gamma,
        epsilon):
    """
    Forward rule of _layernorm_fp8_dot: normalize x (LN or RMSNorm) with a
    fused FP8-cast kernel, cast the GEMM kernel to FP8, then run the FP8 dot.

    Returns:
        output: GEMM result in x.dtype.
        ctx: residuals consumed by _layernorm_fp8_dot_bwd_rule.
    """
    # Fix: canonicalize the type string like the non-FP8 path does. Previously
    # a valid alias such as 'Layer-Norm' or 'rms_norm' silently fell into the
    # rmsnorm branch (beta ignored -> wrong result).
    layernorm_type = canonicalize_layernorm_type(layernorm_type)

    # x: (batch..., hidden_in) contracts against kernel: (hidden_in, hidden_out...)
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[-1] == kernel.shape[0]

    # Roll the amax history before reading current amax values.
    amax = FP8Helper.update_amax_history(amax)

    # FP8 meta slots for the input and kernel (the grad slot is bwd-only).
    gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm_x_idx, 0:1]
    x_scale = scale[gemm_x_idx]
    x_scale_inv = scale_inv[gemm_x_idx]

    if layernorm_type == 'layernorm':
        # Fused LN + FP8 cast; mu/rsigma are residuals for the backward pass.
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        # RMSNorm keeps no mean statistic.
        mu = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")

    assert x.shape == ln_out.shape

    kernel_amax = amax[gemm_kernel_idx, 0:1]
    kernel_scale = scale[gemm_kernel_idx]
    kernel_scale_inv = scale_inv[gemm_kernel_idx]

    # Kernel in (hidden_in, hidden_out...)
    # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel, updated_kernel_amax = \
        cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)

    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
                          (x_contracting_dims, k_contracting_dims),
                          get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    # Residuals: FP8 tensors/metadata for the GEMM grads, and the
    # normalization statistics/inputs for the LN/RMSNorm backward.
    ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
           k_contracting_dims)

    return output, ctx


206
def _layernorm_fp8_dot_bwd_rule(
        layernorm_type,
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        zero_centered_gamma,
        epsilon,
        ctx,
        grad):
    """
    Backward rule of _layernorm_fp8_dot: compute wgrad and dgrad via FP8
    GEMMs, route dgrad through the LN/RMSNorm backward kernel, and update
    the FP8 amax history and scaling factors.

    Returns:
        (dx, wgrad, dgamma, dbeta, fp8_max, amax, scale, scale_inv) —
        gradients for the differentiable inputs of _layernorm_fp8_dot,
        in the same order; dbeta is None for rmsnorm.
    """
    ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
    updated_x_amax, updated_kernel_amax, \
    x_shape, kernel_shape, mu, rsigma, x, gamma, \
    x_contracting_dims, k_contracting_dims = ctx

    # Fix: canonicalize like the non-FP8 path. Previously an alias such as
    # 'Layer-Norm' silently fell through to the rmsnorm branch here.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)

    ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)

    gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    grad_amax = amax[gemm_grad_idx, 0:1]
    grad_scale = scale[gemm_grad_idx]
    grad_scale_inv = scale_inv[gemm_grad_idx]

    # Cast the incoming grad to FP8 and produce its transpose in one fused op.
    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=min(x_contracting_dims))

    # wgrad = ln_out^T . grad, contracting over the batch dims.
    xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape)))
    gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
    x_scale_inv = scale_inv[gemm_x_idx]
    wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
                         (xt_constracting_dim, gt_constracting_dim),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # dgrad = grad . kernel, contracting over the hidden_out dims.
    g_for_dgrad_constracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
    k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
                         (g_for_dgrad_constracting_dim, k_constracting_dim),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")

    # Write the freshly-observed amax values into slot 0 of each history.
    amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
    amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])

    # Recompute scaling factors from the updated amax history.
    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dx, wgrad, \
           dgamma, dbeta, \
           fp8_max, amax, scale, scale_inv
269
270


271
# Wire the custom forward/backward rules into the FP8 custom_vjp primitive.
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)