# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX layernorm modules"""

6
from functools import partial
7
8
9
import jax
import jax.numpy as jnp

10
from .cpp_extensions import cast_transpose, transpose
11
12
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
13
14
from .dot import fp8_dot_impl
from .fp8 import FP8Helper, FP8MetaPackage
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


def canonicalize_layernorm_type(x):
    """
    Canonicalize the layernorm type.

    The name is lower-cased, stripped of surrounding whitespace, and has every
    '-' and '_' removed. The result must be 'layernorm' or 'rmsnorm'; anything
    else trips the assertion.
    """
    normalized = x.lower().strip()
    for separator in ('-', '_'):
        normalized = normalized.replace(separator, '')
    assert normalized in ('layernorm', 'rmsnorm')
    return normalized


def layernorm(inputs: jnp.ndarray,
              gamma: jnp.ndarray,
              beta: jnp.ndarray,
              layernorm_type: str,
              zero_centered_gamma: bool = False,
              epsilon: float = 1e-6):
    """
    LN/RMSNorm wrapper
    Only support layernorm_type in ['layernorm', 'rmsnorm']

    Parameters
    ----------
    inputs, gamma, beta:
        Forwarded unchanged to the custom-VJP function ``_layernorm``.
    layernorm_type:
        Normalization name. It is canonicalized here (case, surrounding
        whitespace, '-' and '_' are ignored), so spellings such as
        'LayerNorm' or 'rms_norm' are accepted.
    zero_centered_gamma, epsilon:
        Forwarded unchanged to ``_layernorm``.

    Returns
    -------
    The output of ``_layernorm``.

    Raises
    ------
    AssertionError
        If ``layernorm_type`` does not canonicalize to a supported name.
    """
    # Non-differentiable arguments are handed verbatim to the backward rule,
    # so normalize the name once here: the forward and backward passes then
    # always agree on which normalization is being used.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    return _layernorm(inputs,
                      gamma,
                      beta,
                      layernorm_type=layernorm_type,
                      zero_centered_gamma=zero_centered_gamma,
                      epsilon=epsilon)


45
46
47
48
49
50
51
52
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x,
               gamma,
               beta,
               layernorm_type: str,
               zero_centered_gamma: bool = False,
               epsilon: float = 1e-6):
    """
    Custom-VJP primal: run the forward rule and keep only its output.

    ``layernorm_type``, ``zero_centered_gamma`` and ``epsilon`` (positions
    3, 4, 5) are declared non-differentiable.
    """
    fwd_result = _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma,
                                     epsilon)
    # The forward rule returns (output, residuals); residuals are not needed here.
    output, _ = fwd_result
    return output


56
57
58
59
60
61
62
def _layernorm_fwd_rule(x,
                        gamma,
                        beta,
                        layernorm_type: str,
                        zero_centered_gamma: bool = False,
                        epsilon: float = 1e-6):
    """
    Forward rule of ``_layernorm``.

    Returns ``(output, residuals)`` where residuals are
    ``(x, mu, rsigma, gamma)``; ``mu`` is ``None`` on the RMSNorm path.

    Raises ``ValueError`` for an unsupported canonical type and asserts that
    ``zero_centered_gamma`` is not set together with 'rmsnorm'.
    """
    layernorm_type = canonicalize_layernorm_type(layernorm_type)

    if layernorm_type == 'layernorm':
        output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
        return output, (x, mu, rsigma, gamma)

    if layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
        # RMSNorm produces no mean; keep the residual layout uniform with None.
        return output, (x, None, rsigma, gamma)

    raise ValueError(f"{layernorm_type=} is not supported.")
73
74


75
76
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
    """
    Backward rule of ``_layernorm``.

    Parameters
    ----------
    layernorm_type, zero_centered_gamma, epsilon:
        The non-differentiable arguments of ``_layernorm``, as originally
        passed by the caller.
    ctx:
        Residuals saved by the forward rule: ``(x, mu, rsigma, gamma)``.
    dz:
        Incoming cotangent of the output.

    Returns
    -------
    ``(dx, dgamma, dbeta)``; ``dbeta`` is ``None`` on the RMSNorm path.

    Raises
    ------
    AssertionError
        If the type is not recognised, or ``zero_centered_gamma`` is set with
        'rmsnorm'.
    ValueError
        If the canonical type is not handled below.
    """
    # The forward rule canonicalizes the type before dispatching, but this
    # rule receives the caller's raw string. Canonicalize it the same way so
    # a spelling accepted in the forward pass (e.g. 'LayerNorm') is not
    # rejected here.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)

    x, mu, rsigma, gamma = ctx
    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dz,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
        # No beta gradient is computed on the RMSNorm path.
        dbeta = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")

    return dx, dgamma, dbeta
94
95


96
# Register the forward and backward rules as the custom VJP of `_layernorm`.
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
97
98


99
100
def layernorm_fp8_dot(x: jnp.ndarray,
                      kernel: jnp.ndarray,
                      gamma: jnp.ndarray,
                      beta: jnp.ndarray,
                      fp8_meta_pkg: FP8MetaPackage,
                      layernorm_type: str,
                      zero_centered_gamma: bool = False,
                      epsilon: float = 1e-6) -> jnp.ndarray:
    """
    Layernorm + FP8 GEMM

    Parameters
    ----------
    x, kernel, gamma, beta:
        Forwarded unchanged to ``_layernorm_fp8_dot``.
    fp8_meta_pkg:
        Supplies ``fp8_max``, ``amax``, ``scale`` and ``scale_inv``.
    layernorm_type:
        Normalization name. It is canonicalized here (case, surrounding
        whitespace, '-' and '_' are ignored) and must resolve to 'layernorm'
        or 'rmsnorm'.
    zero_centered_gamma, epsilon:
        Forwarded unchanged to ``_layernorm_fp8_dot``.

    Returns
    -------
    The output of ``_layernorm_fp8_dot``.

    Raises
    ------
    AssertionError
        If ``layernorm_type`` does not canonicalize to a supported name.
    """
    # The FP8 forward/backward rules pick the LayerNorm kernels only when the
    # type is exactly 'layernorm' and treat every other string as RMSNorm.
    # Canonicalizing (and validating) here prevents e.g. 'LayerNorm' or a typo
    # from being silently executed as RMSNorm.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)

    fp8_max = fp8_meta_pkg.fp8_max
    amax = fp8_meta_pkg.amax
    scale = fp8_meta_pkg.scale
    scale_inv = fp8_meta_pkg.scale_inv
    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE
    output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
                                layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon)
    return output


121
122
123
124
125
126
127
128
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                       fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                       scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
                       bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float):
    """
    Custom-VJP primal: run the FP8 forward rule and keep only its output.

    Arguments at positions 8-12 (``layernorm_type``, ``fwd_dtype``,
    ``bwd_dtype``, ``zero_centered_gamma``, ``epsilon``) are declared
    non-differentiable.
    """
    fwd_result = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale,
                                             scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
                                             zero_centered_gamma, epsilon)
    # The forward rule returns (output, ctx); ctx is only needed by the backward rule.
    output, _ = fwd_result
    return output


132
133
def _layernorm_fp8_dot_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        fp8_max,
        amax,
        scale,
        scale_inv,
        layernorm_type,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        zero_centered_gamma,
        epsilon):
    """
    Forward rule of ``_layernorm_fp8_dot``.

    Normalizes ``x`` with the FP8 LayerNorm kernel when ``layernorm_type`` is
    exactly 'layernorm' and with the FP8 RMSNorm kernel otherwise, casts the
    kernel via ``cast_transpose``, and multiplies the two with
    ``fp8_dot_impl``.

    Returns ``(output, ctx)`` where ``ctx`` is the tuple of residuals
    unpacked by the backward rule.
    """
    # The last axis of x contracts against the first axis of the kernel.
    assert x.shape[-1] == kernel.shape[0]
    x_contracting_dims = (x.ndim - 1,)
    k_contracting_dims = (0,)

    amax = FP8Helper.update_amax_history(amax)

    x_idx, kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    # FP8 metadata slices for the activation.
    x_amax, x_scale, x_scale_inv = amax[x_idx, 0:1], scale[x_idx], scale_inv[x_idx]

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        # The RMSNorm kernel returns no mean.
        mu = None
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)

    assert x.shape == ln_out.shape

    # FP8 metadata slices for the kernel.
    kernel_amax = amax[kernel_idx, 0:1]
    kernel_scale, kernel_scale_inv = scale[kernel_idx], scale_inv[kernel_idx]

    # Kernel in (hidden_in, hidden_out...)
    casted_kernel, casted_kernel_t, updated_kernel_amax = cast_transpose(
        kernel,
        kernel_amax,
        kernel_scale,
        kernel_scale_inv,
        fwd_dtype,
        static_axis_boundary=-1,
        transpose_axis_boundary=1)

    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    # The GEMM consumes the transposed kernel, so it contracts on its last axis.
    kt_contracting_dims = (kernel.ndim - 1,)
    gemm_dims = (x_contracting_dims, kt_contracting_dims)
    output = fp8_dot_impl(ln_out, casted_kernel_t, x_scale_inv, kernel_scale_inv, x.dtype,
                          gemm_dims)

    # Order must match the unpacking in the backward rule.
    ctx = (
        ln_out,
        casted_kernel,
        fp8_max,
        amax,
        scale,
        scale_inv,
        updated_x_amax,
        updated_kernel_amax,
        x.shape,
        kernel.shape,
        mu,
        rsigma,
        x,
        gamma,
        x_contracting_dims,
        k_contracting_dims,
    )

    return output, ctx


205
def _layernorm_fp8_dot_bwd_rule(
        layernorm_type,
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        zero_centered_gamma,
        epsilon,
        ctx,
        grad):
    """
    Backward rule of ``_layernorm_fp8_dot``.

    Computes the kernel gradient and the gradient w.r.t. the normalized
    activation with two ``fp8_dot_impl`` calls, then back-propagates through
    LayerNorm (when ``layernorm_type`` is exactly 'layernorm') or RMSNorm
    (otherwise).

    Returns one entry per differentiable input of ``_layernorm_fp8_dot``:
    ``(dx, wgrad, dgamma, dbeta, fp8_max, amax, scale, scale_inv)``. The last
    four entries carry the refreshed FP8 metadata rather than gradients, and
    ``dbeta`` is ``None`` on the RMSNorm path.
    """
    (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv,
     updated_x_amax, updated_kernel_amax,
     x_shape, kernel_shape, mu, rsigma, x, gamma,
     x_contracting_dims, k_contracting_dims) = ctx

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    x_idx, kernel_idx, grad_idx = FP8Helper.get_fp8_meta_indices(0)

    # FP8 metadata slices for the incoming gradient.
    grad_amax = amax[grad_idx, 0:1]
    grad_scale, grad_scale_inv = scale[grad_idx], scale_inv[grad_idx]

    casted_grad, casted_grad_t, updated_grad_amax = cast_transpose(
        grad,
        grad_amax,
        grad_scale,
        grad_scale_inv,
        bwd_dtype,
        static_axis_boundary=-1,
        transpose_axis_boundary=min(x_contracting_dims))

    # Kernel gradient: transposed activation x transposed gradient.
    xt_contracting_dims = tuple(range(len(x_contracting_dims), len(x_shape)))
    gt_contracting_dims = tuple(range(grad.ndim - len(xt_contracting_dims), grad.ndim))
    x_scale_inv = scale_inv[x_idx]
    wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
                         (xt_contracting_dims, gt_contracting_dims))

    # Gradient w.r.t. the normalized activation: gradient x casted kernel.
    num_kernel_axes = len(kernel_shape)
    num_k_contracting = len(k_contracting_dims)
    g_contracting_dims = tuple(range(grad.ndim - num_kernel_axes + num_k_contracting, grad.ndim))
    kernel_contracting_dims = tuple(range(num_k_contracting, num_kernel_axes))
    kernel_scale_inv = scale_inv[kernel_idx]
    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv,
                         grad.dtype, (g_contracting_dims, kernel_contracting_dims))

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        # No beta gradient is computed on the RMSNorm path.
        dbeta = None
        dx, dgamma = rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)

    # Write the amax values observed in this step into column 0, then
    # recompute the scales from the updated amax.
    amax = amax.at[x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[kernel_idx, 0].set(updated_kernel_amax[0])
    amax = amax.at[grad_idx, 0].set(updated_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dx, wgrad, dgamma, dbeta, fp8_max, amax, scale, scale_inv
266
267


268
# Register the forward and backward rules as the custom VJP of `_layernorm_fp8_dot`.
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)