# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX layernorm modules"""

from functools import partial
from typing import List, Tuple

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes

def canonicalize_layernorm_type(x):
    '''
    Canonicalize the layernorm type: lowercase, trim surrounding whitespace,
    and drop '-'/'_' separators, then validate the result.
    '''
    normalized = x.lower().strip().replace('-', '').replace('_', '')
    assert normalized in ('layernorm', 'rmsnorm')
    return normalized


def layernorm(inputs: jnp.ndarray,
              gamma: jnp.ndarray,
              beta: jnp.ndarray,
              layernorm_type: str,
              zero_centered_gamma: bool = False,
              epsilon: float = 1e-6):
    """
    Public LN/RMSNorm wrapper around the custom-VJP implementation.
    Only layernorm_type in ['layernorm', 'rmsnorm'] is supported.
    """
    # Delegate straight to the differentiable implementation.
    return _layernorm(inputs,
                      gamma,
                      beta,
                      layernorm_type=layernorm_type,
                      zero_centered_gamma=zero_centered_gamma,
                      epsilon=epsilon)


46
47
48
49
50
51
52
53
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x,
               gamma,
               beta,
               layernorm_type: str,
               zero_centered_gamma: bool = False,
               epsilon: float = 1e-6):
    """custom_vjp-wrapped LN/RMSNorm; the primal path reuses the fwd rule."""
    # The fwd rule returns (output, residuals); only the output is the primal.
    return _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma,
                               epsilon)[0]


57
58
59
60
61
62
63
def _layernorm_fwd_rule(x,
                        gamma,
                        beta,
                        layernorm_type: str,
                        zero_centered_gamma: bool = False,
                        epsilon: float = 1e-6):
    """Forward rule: run the normalization kernel and stash residuals for bwd."""
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        # RMSNorm has no mean subtraction, hence no mu residual.
        output, rsigma = tex.rmsnorm_fwd(x, gamma, epsilon)
        mu = None
    elif layernorm_type == 'layernorm':
        output, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")
    # Residuals consumed by _layernorm_bwd_rule.
    return output, (x, mu, rsigma, gamma)
74
75


76
77
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
    """Backward rule for the _layernorm custom_vjp.

    Args:
        layernorm_type: norm flavor; accepted aliases (e.g. 'layer-norm') are
            canonicalized exactly as in the fwd rule.
        zero_centered_gamma: whether gamma is stored zero-centered (layernorm only).
        epsilon: numerical-stability constant forwarded to the kernels.
        ctx: residuals (x, mu, rsigma, gamma) saved by _layernorm_fwd_rule.
        dz: cotangent of the layernorm output.

    Returns:
        Gradients (dx, dgamma, dbeta); dbeta is None for rmsnorm.
    """
    # Fix: the fwd rule canonicalizes layernorm_type before dispatching, but the
    # raw nondiff argument is what custom_vjp passes here. Without canonicalizing,
    # an alias such as 'layer_norm' runs the primal fine yet raises ValueError on
    # the backward pass. Canonicalize to stay consistent with the fwd rule.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    x, mu, rsigma, gamma = ctx
    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = tex.layernorm_bwd(dz,
                                              x,
                                              mu,
                                              rsigma,
                                              gamma,
                                              zero_centered_gamma=zero_centered_gamma,
                                              epsilon=epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = tex.rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
        # RMSNorm has no beta, so no beta gradient.
        dbeta = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")

    return dx, dgamma, dbeta
95
96


97
# Register the forward/backward rules with the custom_vjp primitive.
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
98
99


100
101
102
103
104
105
106
107
108
109
110
111
112
113
def layernorm_fp8_dot(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    fp8_meta_pkg: FP8MetaPackage,
    layernorm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    layernorm_input_axes: Tuple[
        str, ...] = None,    # The logic axes of sharding constraint to the layernorm input.
    dot_input_axes: Tuple[str,
                          ...] = None    # The logic axes of sharding constraint to the dot input.
) -> jnp.ndarray:
    """
    Layernorm + FP8 GEMM: public wrapper that unpacks the FP8 metadata and
    forwards everything to the custom-VJP implementation.
    """
    # FP8 dtypes for forward (activations/weights) and backward (gradients)
    # come from the global FP8Helper configuration.
    return _layernorm_fp8_dot(x, kernel, gamma, beta,
                              fp8_meta_pkg.amax_list, fp8_meta_pkg.scale_list,
                              layernorm_type,
                              FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE,
                              zero_centered_gamma, epsilon,
                              layernorm_input_axes, dot_input_axes)


127
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                       amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray],
                       layernorm_type: str, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
                       zero_centered_gamma: bool, epsilon: float,
                       layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]):
    """custom_vjp-wrapped layernorm + FP8 dot; the primal path reuses the fwd rule."""
    primal_out, _residuals = _layernorm_fp8_dot_fwd_rule(
        x, kernel, gamma, beta, amax_list, scale_list, layernorm_type, fwd_dtype,
        bwd_dtype, zero_centered_gamma, epsilon, layernorm_input_axes, dot_input_axes)
    return primal_out


140
141
def _layernorm_fp8_dot_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        amax_list,
        scale_list,
        layernorm_type,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes):
    """Forward rule: normalize x (emitting FP8), cast the kernel to FP8, and
    run the FP8 GEMM. Returns (output, ctx) where ctx carries everything the
    backward rule needs (FP8 tensors, updated amaxes, residuals, shapes).
    """
    # The GEMM contracts x's last axis against kernel's first axis.
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[-1] == kernel.shape[0]

    # FP8 metadata may be stored in a special fm32 dtype; get a converter pair
    # so we can work in fp32 here and convert back in the bwd rule.
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
        FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list, *scale_list)
    amax_list = maybe_fm32_to_fp32(*amax_list)
    scale_list = maybe_fm32_to_fp32(*scale_list)

    # Refresh scales from the amax history: forward dtype for input and weight,
    # backward dtype for the gradient slot.
    fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
    scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(amax_list, scale_list,
                                                                 fp8_dtype_list)
    amax_list = FP8MetaPackage.update_amax_list(amax_list)

    # Per-tensor FP8 metadata for the layernorm input (amax is the leading entry
    # of the amax history).
    x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1]
    x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        # Fused layernorm that writes its output directly in FP8.
        ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        # RMSNorm has no mean residual.
        mu = None

    assert x.shape == ln_out.shape

    # FP8 metadata for the weight tensor.
    kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1]
    kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]

    # Kernel in (hidden_in, hidden_out...)
    # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel, updated_kernel_amax = \
        tex.cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
                          (x_contracting_dims, k_contracting_dims),
                          get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    # Residuals for the backward rule; must stay in sync with the unpacking in
    # _layernorm_fp8_dot_bwd_rule.
    ctx = (ln_out, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax,
           updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
           k_contracting_dims, maybe_fp32_to_fm32)

    return output, ctx


224
def _layernorm_fp8_dot_bwd_rule(
        layernorm_type,
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes,    # pylint: disable=unused-argument
        ctx,
        grad):
    """Backward rule: compute wgrad/dgrad via FP8 GEMMs, backprop through the
    normalization, and write the observed amaxes back into the FP8 metadata.
    Returns cotangents for every differentiable primal input:
    (dx, wgrad, dgamma, dbeta, amax_list, scale_list).
    """
    # Unpacking order must match the ctx tuple built in the fwd rule.
    ln_out_, casted_kernel, amax_list, scale_list, scale_inv_list, \
    updated_x_amax, updated_kernel_amax, \
    x_shape, kernel_shape, mu, rsigma, x, gamma, \
    x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32 = ctx

    # Transposed ln_out is needed for the weight-gradient GEMM.
    ln_out_t = tex.transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # FP8 metadata for the incoming gradient.
    grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
    grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
    grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]

    # Single fused cast+transpose: both layouts of the FP8 gradient are needed
    # (casted_grad_t for wgrad, casted_grad for dgrad).
    casted_grad, casted_grad_t, updated_grad_amax = \
        tex.cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=min(x_contracting_dims))

    # wgrad = ln_out^T . grad, contracting over the batch axes.
    xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape)))
    gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
    wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
                         (xt_constracting_dim, gt_constracting_dim),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # dgrad = grad . kernel^T, contracting over the hidden_out axes.
    g_for_dgrad_constracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
    k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
                         (g_for_dgrad_constracting_dim, k_constracting_dim),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
    # Backprop through the normalization itself.
    # NOTE(review): unlike the fwd rule, layernorm_type is used here without
    # canonicalization — callers are expected to pass the canonical string.
    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = tex.layernorm_bwd(dgrad,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = tex.rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
        # RMSNorm has no beta, so no beta gradient.
        dbeta = None

    # Record the amaxes observed this step into slot 0 of each amax history.
    amax_list[FP8MetaPackage.INPUT_IDX] = \
        amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
    amax_list[FP8MetaPackage.WEIGHT_IDX] = \
        amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0])
    amax_list[FP8MetaPackage.GRAD_IDX] = \
        amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])

    # Convert the FP8 metadata back to its storage dtype (see fwd rule).
    amax_list = maybe_fp32_to_fm32(*amax_list)
    scale_list = maybe_fp32_to_fm32(*scale_list)

    return dx, wgrad, \
           dgamma, dbeta, \
           amax_list, scale_list
292
293


294
# Register the forward/backward rules with the custom_vjp primitive.
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)