layernorm.py 11.2 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""JAX layernorm modules"""

6
from functools import partial
7
8
from typing import Tuple

9
10
11
import jax
import jax.numpy as jnp

12
from .cpp_extensions import cast_fp8, cast_transpose, transpose
13
14
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
15
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
16
from .fp8 import FP8Helper, FP8MetaPackage
17
from .sharding import with_sharding_constraint_by_logical_axes
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


def canonicalize_layernorm_type(x):
    """Normalize a user-supplied normalization-type string.

    Lower-cases ``x``, strips surrounding whitespace and removes ``-`` and
    ``_`` so spellings such as ``'LayerNorm'``, ``' layer_norm '`` or
    ``'RMS-Norm'`` map onto a single canonical name.

    Args:
        x: The normalization type as a string.

    Returns:
        Either ``'layernorm'`` or ``'rmsnorm'``.

    Raises:
        ValueError: If ``x`` does not name a supported normalization type.
    """
    canonicalized = x.lower().strip().replace('-', '').replace('_', '')
    # Raise instead of `assert` so the check survives `python -O`; ValueError
    # matches what the forward/backward rules raise for unsupported types.
    if canonicalized not in ('layernorm', 'rmsnorm'):
        raise ValueError(f"layernorm_type={x!r} is not supported; "
                         "expected 'layernorm' or 'rmsnorm'.")
    return canonicalized


def layernorm(inputs: jnp.ndarray,
              gamma: jnp.ndarray,
              beta: jnp.ndarray,
              layernorm_type: str,
              zero_centered_gamma: bool = False,
              epsilon: float = 1e-6):
    """
    LN/RMSNorm wrapper with a custom VJP.

    Only support layernorm_type in ['layernorm', 'rmsnorm'] (case-insensitive,
    '-' and '_' separators are ignored).

    Args:
        inputs: Tensor to normalize.
        gamma: Scale parameter.
        beta: Shift parameter; only consumed when layernorm_type is 'layernorm'.
        layernorm_type: Which normalization to apply.
        zero_centered_gamma: Forwarded to the layernorm kernels; must be False
            for 'rmsnorm'.
        epsilon: Numerical-stability constant forwarded to the kernels.

    Returns:
        The normalized tensor.
    """
    # `layernorm_type` is a non-differentiable argument of the custom VJP, so
    # the backward rule receives exactly the value passed here. Canonicalizing
    # once up front keeps forward and backward dispatch in agreement for
    # spellings such as 'LayerNorm' (previously the backward pass raised).
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    output = _layernorm(inputs,
                        gamma,
                        beta,
                        layernorm_type=layernorm_type,
                        zero_centered_gamma=zero_centered_gamma,
                        epsilon=epsilon)
    return output


48
49
50
51
52
53
54
55
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x,
               gamma,
               beta,
               layernorm_type: str,
               zero_centered_gamma: bool = False,
               epsilon: float = 1e-6):
    """Primal of the normalization custom VJP.

    Runs the forward rule and discards its residuals. Arguments 3-5
    (layernorm_type, zero_centered_gamma, epsilon) are non-differentiable.
    """
    fwd_result = _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma,
                                     epsilon)
    return fwd_result[0]


59
60
61
62
63
64
65
def _layernorm_fwd_rule(x,
                        gamma,
                        beta,
                        layernorm_type: str,
                        zero_centered_gamma: bool = False,
                        epsilon: float = 1e-6):
    """Forward rule for `_layernorm`.

    Returns:
        A pair ``(output, residuals)`` with ``residuals == (x, mu, rsigma, gamma)``.
        ``mu`` is ``None`` on the RMSNorm path because ``rmsnorm_fwd`` returns
        no mean.
    """
    layernorm_type = canonicalize_layernorm_type(layernorm_type)

    if layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, \
            "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        normed, rsigma = rmsnorm_fwd(x, gamma, epsilon)
        residuals = (x, None, rsigma, gamma)
    elif layernorm_type == 'layernorm':
        normed, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
        residuals = (x, mu, rsigma, gamma)
    else:
        # Defensive: canonicalize_layernorm_type is expected to reject other names.
        raise ValueError(f"{layernorm_type=} is not supported.")

    return normed, residuals
76
77


78
79
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
    """Backward rule for `_layernorm`.

    Args:
        layernorm_type: The non-differentiable argument exactly as passed to
            `_layernorm` (not necessarily canonicalized).
        zero_centered_gamma: Forwarded to ``layernorm_bwd``; must be False for
            'rmsnorm'.
        epsilon: Forwarded to the backward kernels.
        ctx: Residuals ``(x, mu, rsigma, gamma)`` from `_layernorm_fwd_rule`.
        dz: Cotangent of the normalized output.

    Returns:
        ``(dx, dgamma, dbeta)``; ``dbeta`` is ``None`` on the RMSNorm path since
        ``rmsnorm_bwd`` produces no beta gradient.

    Raises:
        ValueError: If ``layernorm_type`` is not a supported normalization type.
    """
    # The forward rule dispatches on the canonicalized name, but custom_vjp
    # hands this rule the raw argument. Canonicalize here as well so that e.g.
    # 'LayerNorm' takes the layernorm branch instead of raising.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    x, mu, rsigma, gamma = ctx
    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dz,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")

    return dx, dgamma, dbeta
97
98


99
# Register the forward/backward rules for `_layernorm`'s custom VJP.
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
100
101


102
103
104
105
106
107
108
109
110
111
112
113
114
115
def layernorm_fp8_dot(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    fp8_meta_pkg: FP8MetaPackage,
    layernorm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    # The logical axes of the sharding constraint applied to the layernorm input.
    layernorm_input_axes: Tuple[str, ...] = None,
    # The logical axes of the sharding constraint applied to the dot input.
    dot_input_axes: Tuple[str, ...] = None,
) -> jnp.ndarray:
    """
    Layernorm/RMSNorm followed by an FP8 GEMM, with a custom VJP.

    Args:
        x: Input of shape (batch..., hidden_in).
        kernel: GEMM weight of shape (hidden_in, hidden_out...).
        gamma: Normalization scale parameter.
        beta: Normalization shift parameter; only consumed for 'layernorm'.
        fp8_meta_pkg: Supplies the ``fp8_max``, ``amax``, ``scale`` and
            ``scale_inv`` FP8 meta tensors.
        layernorm_type: 'layernorm' or 'rmsnorm' (case-insensitive, '-' and
            '_' separators are ignored).
        zero_centered_gamma: Forwarded to the layernorm kernels; must be False
            for 'rmsnorm'.
        epsilon: Numerical-stability constant forwarded to the kernels.
        layernorm_input_axes: Logical axes for the sharding constraint on ``x``.
        dot_input_axes: Logical axes for the sharding constraint on the
            normalized GEMM input.

    Returns:
        The GEMM output (the forward rule requests ``x.dtype`` from the FP8 dot).
    """
    # Canonicalize before entering the custom VJP. Both the forward and the
    # backward rule compare this non-differentiable argument against the
    # literal 'layernorm' and route everything else to RMSNorm, so a spelling
    # such as 'LayerNorm' would otherwise silently compute RMSNorm. This also
    # rejects unsupported names instead of treating them as RMSNorm.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)

    fp8_max = fp8_meta_pkg.fp8_max
    amax = fp8_meta_pkg.amax
    scale = fp8_meta_pkg.scale
    scale_inv = fp8_meta_pkg.scale_inv
    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE
    output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
                                layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon,
                                layernorm_input_axes, dot_input_axes)
    return output


131
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                       fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                       scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
                       bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float,
                       layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]):
    """Primal of the FP8 layernorm + GEMM custom VJP.

    Arguments 8-14 (layernorm_type through dot_input_axes) are
    non-differentiable. Runs the forward rule and drops its residuals.
    """
    fwd_result = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta,
                                             fp8_max, amax, scale, scale_inv,
                                             layernorm_type, fwd_dtype, bwd_dtype,
                                             zero_centered_gamma, epsilon,
                                             layernorm_input_axes, dot_input_axes)
    return fwd_result[0]


144
145
def _layernorm_fp8_dot_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        fp8_max,
        amax,
        scale,
        scale_inv,
        layernorm_type,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes):
    """Forward rule for `_layernorm_fp8_dot`.

    Normalizes ``x`` directly into ``fwd_dtype`` with the fused FP8
    layernorm/rmsnorm kernel, casts ``kernel`` to ``fwd_dtype``, and runs the
    FP8 GEMM that contracts the last axis of ``x`` with axis 0 of ``kernel``.
    Any ``layernorm_type`` other than ``'layernorm'`` takes the RMSNorm path.

    Returns:
        ``(output, ctx)``. ``ctx`` holds the FP8 GEMM operands, ``fp8_max``,
        ``amax`` (after ``FP8Helper.update_amax_history``), ``scale``,
        ``scale_inv``, the updated amaxes returned by the normalization and
        cast kernels, the input/kernel shapes, ``mu``/``rsigma``, ``x``,
        ``gamma`` and the GEMM contracting dimensions. All of these are
        consumed by `_layernorm_fp8_dot_bwd_rule`.
    """
    # NOTE: statement order mirrors the original implementation on purpose;
    # the FP8 GEMM relies on XLA pattern matching (see the cast note below).

    # Contract the last axis of x (hidden_in) with axis 0 of kernel.
    lhs_dims = (x.ndim - 1,)
    rhs_dims = (0,)
    assert x.shape[-1] == kernel.shape[0]

    amax = FP8Helper.update_amax_history(amax)

    x_idx, k_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_meta_amax = amax[x_idx, 0:1]
    x_meta_scale = scale[x_idx]
    x_meta_scale_inv = scale_inv[x_idx]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        normed_fp8, mu, rsigma, new_x_amax = layernorm_fwd_fp8(
            x, gamma, beta, x_meta_amax, x_meta_scale, x_meta_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, \
            "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        normed_fp8, rsigma, new_x_amax = rmsnorm_fwd_fp8(
            x, gamma, x_meta_amax, x_meta_scale, x_meta_scale_inv,
            out_dtype=fwd_dtype,
            epsilon=epsilon)
        mu = None    # rmsnorm_fwd_fp8 returns no mean

    assert x.shape == normed_fp8.shape

    k_meta_amax = amax[k_idx, 0:1]
    k_meta_scale = scale[k_idx]
    k_meta_scale_inv = scale_inv[k_idx]

    # kernel is laid out as (hidden_in, hidden_out...).
    # NOTE(Ming Huang): only cast here (no cast+transpose) and let XLA handle
    # the transpose, avoiding an unnecessary copy that would break XLA's FP8
    # GEMM pattern matching.
    kernel_fp8, new_kernel_amax = cast_fp8(kernel, k_meta_amax, k_meta_scale,
                                           k_meta_scale_inv, fwd_dtype)

    normed_fp8 = with_sharding_constraint_by_logical_axes(normed_fp8, dot_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    output = fp8_dot_impl(normed_fp8, kernel_fp8, x_meta_scale_inv, k_meta_scale_inv, x.dtype,
                          (lhs_dims, rhs_dims),
                          get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    residuals = (normed_fp8, kernel_fp8, fp8_max, amax, scale, scale_inv,
                 new_x_amax, new_kernel_amax,
                 x.shape, kernel.shape, mu, rsigma, x, gamma,
                 lhs_dims, rhs_dims)
    return output, residuals


224
def _layernorm_fp8_dot_bwd_rule(
        layernorm_type,
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes,    # pylint: disable=unused-argument
        ctx,
        grad):
    """Backward rule for `_layernorm_fp8_dot`.

    Casts ``grad`` to ``bwd_dtype`` (plain and transposed), computes the
    weight gradient and the GEMM input gradient with FP8 GEMMs, and then
    back-propagates through the normalization. Any ``layernorm_type`` other
    than ``'layernorm'`` takes the RMSNorm path.

    Returns:
        ``(dx, wgrad, dgamma, dbeta, fp8_max, amax, scale, scale_inv)``, one
        entry per differentiable input. The last four slots are not true
        gradients. They return ``fp8_max`` unchanged, ``amax`` with the new
        amaxes written into column 0, and ``scale``/``scale_inv`` as recomputed
        by ``FP8Helper.update_fp8_scale``. ``dbeta`` is ``None`` on the
        RMSNorm path.
    """
    # NOTE: statement order mirrors the original implementation on purpose.
    (normed_fp8, kernel_fp8, fp8_max, amax, scale, scale_inv,
     new_x_amax, new_kernel_amax,
     x_shape, kernel_shape, mu, rsigma, x, gamma,
     lhs_dims, rhs_dims) = ctx

    # Transposed FP8 copy of the normalized input, used by the wgrad GEMM.
    normed_fp8_t = transpose(normed_fp8, static_axis_boundary=-1, transpose_axis_boundary=-1)

    x_idx, k_idx, g_idx = FP8Helper.get_fp8_meta_indices(0)

    g_meta_amax = amax[g_idx, 0:1]
    g_meta_scale = scale[g_idx]
    g_meta_scale_inv = scale_inv[g_idx]

    # Quantize the incoming gradient, producing the plain copy (for dgrad) and
    # the transposed copy (for wgrad) in one call.
    grad_fp8, grad_fp8_t, new_grad_amax = cast_transpose(
        grad, g_meta_amax, g_meta_scale, g_meta_scale_inv, bwd_dtype,
        static_axis_boundary=-1, transpose_axis_boundary=min(lhs_dims))

    # wgrad: contract every axis of the transposed operands except the
    # leading hidden axes.
    normed_t_dims = tuple(range(len(lhs_dims), len(x_shape)))
    grad_t_dims = tuple(range(grad.ndim - len(normed_t_dims), grad.ndim))
    x_meta_scale_inv = scale_inv[x_idx]
    wgrad = fp8_dot_impl(normed_fp8_t, grad_fp8_t, x_meta_scale_inv, g_meta_scale_inv, grad.dtype,
                         (normed_t_dims, grad_t_dims),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # dgrad: contract the trailing (hidden_out...) axes of grad with the
    # non-contracted axes of kernel.
    grad_dims = tuple(range(grad.ndim - len(kernel_shape) + len(rhs_dims), grad.ndim))
    kernel_dims = tuple(range(len(rhs_dims), len(kernel_shape)))
    k_meta_scale_inv = scale_inv[k_idx]
    dgrad = fp8_dot_impl(grad_fp8, kernel_fp8, g_meta_scale_inv, k_meta_scale_inv, grad.dtype,
                         (grad_dims, kernel_dims),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad, x, mu, rsigma, gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, \
            "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    # Store the amaxes returned by the casts/normalization into column 0 of
    # each meta row, then recompute scale/scale_inv from the updated amax.
    amax = amax.at[x_idx, 0].set(new_x_amax[0])
    amax = amax.at[k_idx, 0].set(new_kernel_amax[0])
    amax = amax.at[g_idx, 0].set(new_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dx, wgrad, dgamma, dbeta, fp8_max, amax, scale, scale_inv
290
291


292
# Register the forward/backward rules for `_layernorm_fp8_dot`'s custom VJP.
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)