# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX layernorm modules"""

6
from functools import partial
7
from typing import List, Tuple
8

9
10
11
import jax
import jax.numpy as jnp

12
from . import cpp_extensions as tex
13
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
14
from .fp8 import FP8Helper, FP8MetaPackage
15
from .sharding import with_sharding_constraint_by_logical_axes
16
17
18


def canonicalize_layernorm_type(x):
    """
    Canonicalize the layernorm type.

    Lower-cases the input and strips surrounding whitespace, hyphens, and
    underscores, so spellings such as "LayerNorm", "layer-norm", or
    "rms_norm" all map to the canonical forms.

    Args:
        x: Layernorm type string.

    Returns:
        The canonical type string, one of "layernorm" or "rmsnorm".
    """
    canonicalized = x.lower().strip().replace("-", "").replace("_", "")
    # Fail loudly with the offending value instead of a bare AssertionError.
    assert canonicalized in ["layernorm", "rmsnorm"], (
        f"Unsupported layernorm type {x!r} (canonicalized to {canonicalized!r}); "
        "expected one of ['layernorm', 'rmsnorm']"
    )
    return canonicalized


27
28
29
30
31
32
33
34
def layernorm(
    inputs: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    layernorm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
):
    """
    LN/RMSNorm  wrapper
    Only support layernorm_type in ['layernorm', 'rmsnorm']
    """
    # Thin public entry point: all the work (including the custom VJP
    # registration) lives in _layernorm.
    return _layernorm(
        inputs,
        gamma,
        beta,
        layernorm_type=layernorm_type,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
    )


50
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(
    x, gamma, beta, layernorm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6
):
    """Differentiable LN/RMSNorm primal with a custom VJP."""
    # The primal simply runs the forward rule; the residuals are discarded
    # here because jax.custom_vjp re-invokes the fwd rule when differentiating.
    out, _residuals = _layernorm_fwd_rule(
        x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon
    )
    return out


58
59
60
def _layernorm_fwd_rule(
    x, gamma, beta, layernorm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6
):
    """Forward rule: run the fused norm kernel and collect residuals for bwd."""
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "rmsnorm":
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        output, rsigma = tex.rmsnorm_fwd(x, gamma, epsilon)
        # RMSNorm computes no mean statistic.
        mu = None
    elif layernorm_type == "layernorm":
        output, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")
    # Residuals consumed by _layernorm_bwd_rule.
    return output, (x, mu, rsigma, gamma, beta)
73
74


75
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
    """Backward rule: compute input/scale/shift gradients from the residuals.

    Args:
        layernorm_type: norm variant, 'layernorm' or 'rmsnorm' (non-diff arg).
        zero_centered_gamma: whether gamma is stored zero-centered (non-diff arg).
        epsilon: numerical-stability constant (non-diff arg).
        ctx: residuals saved by _layernorm_fwd_rule: (x, mu, rsigma, gamma, beta).
        dz: incoming cotangent of the norm output.

    Returns:
        Gradients (dx, dgamma, dbeta); dbeta is None for rmsnorm.
    """
    x, mu, rsigma, gamma, beta = ctx
    # Canonicalize to match the fwd rule: jax.custom_vjp forwards the
    # non-diff argument verbatim, so a spelling like "Layer-Norm" that the
    # fwd rule accepted would otherwise fall through to the ValueError here.
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "layernorm":
        dx, dgamma, dbeta = tex.layernorm_bwd(
            dz, x, mu, rsigma, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
        )
    elif layernorm_type == "rmsnorm":
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        dx, dgamma = tex.rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
        # RMSNorm has no shift parameter, hence no beta gradient.
        dbeta = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")

    return dx, dgamma, dbeta
91
92


93
# Register the custom forward/backward rules so jax.grad of _layernorm uses
# the fused TE kernels instead of tracing through them.
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
94
95


96
97
98
99
100
101
102
103
104
105
def layernorm_fp8_dot(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    fp8_meta_pkg: FP8MetaPackage,
    layernorm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    layernorm_input_axes: Tuple[
        str, ...
    ] = None,  # The logic axes of sharding constraint to the layernorm input.
    dot_input_axes: Tuple[
        str, ...
    ] = None,  # The logic axes of sharding constraint to the dot input.
) -> jnp.ndarray:
    """
    Layernorm + FP8 GEMM
    """
    # Unpack the FP8 metadata and globally-configured FP8 dtypes, then defer
    # to the custom-VJP implementation.
    return _layernorm_fp8_dot(
        x,
        kernel,
        gamma,
        beta,
        fp8_meta_pkg.amax_list,
        fp8_meta_pkg.scale_list,
        layernorm_type,
        FP8Helper.FWD_DTYPE,
        FP8Helper.BWD_DTYPE,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes,
    )


137
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _layernorm_fp8_dot(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    amax_list: List[jnp.ndarray],
    scale_list: List[jnp.ndarray],
    layernorm_type: str,
    fwd_dtype: jnp.dtype,
    bwd_dtype: jnp.dtype,
    zero_centered_gamma: bool,
    epsilon: float,
    layernorm_input_axes: Tuple[str, ...],
    dot_input_axes: Tuple[str, ...],
):
    """Fused layernorm + FP8 dot primal; custom VJP supplies the gradients."""
    # Residuals are dropped here: jax.custom_vjp re-runs the fwd rule during
    # differentiation to collect them.
    out, _residuals = _layernorm_fp8_dot_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        amax_list,
        scale_list,
        layernorm_type,
        fwd_dtype,
        bwd_dtype,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes,
    )
    return out


def _layernorm_fp8_dot_fwd_rule(
    x,
    kernel,
    gamma,
    beta,
    amax_list,
    scale_list,
    layernorm_type,
    fwd_dtype,
    bwd_dtype,  # pylint: disable=unused-argument
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_input_axes,
):
    """Forward rule for the fused layernorm + FP8 GEMM.

    Updates the FP8 scaling metadata, runs the norm with FP8 output, casts
    the kernel to FP8, performs the FP8 dot, and returns the output together
    with the residuals needed by the backward rule.
    """
    # Contract the last axis of x against the first axis of the kernel,
    # i.e. x is (batch..., hidden_in) and kernel is (hidden_in, hidden_out...).
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[-1] == kernel.shape[0]

    # FP8 metas may be stored in a float32-mimicking dtype; convert to plain
    # fp32 for the math here and back again when stashing them in ctx.
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
        *amax_list, *scale_list
    )
    amax_list = maybe_fm32_to_fp32(*amax_list)
    scale_list = maybe_fm32_to_fp32(*scale_list)

    # One dtype per meta slot: input and weight use the forward FP8 dtype,
    # gradient uses the backward FP8 dtype.
    fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
    # Order matters: scales are refreshed from the current amax history
    # before the history itself is rolled forward.
    scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(
        amax_list, scale_list, fp8_dtype_list
    )
    amax_list = FP8MetaPackage.update_amax_list(amax_list)

    # FP8 quantization metadata for the layernorm output (the GEMM input).
    x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1]
    x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == "layernorm":
        # Fused kernel: layernorm with direct FP8 output plus amax tracking.
        ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
        )
    else:
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(
            x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon
        )
        # RMSNorm has no mean statistic.
        mu = None

    assert x.shape == ln_out.shape

    # FP8 quantization metadata for the kernel.
    kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1]
    kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]

    # Kernel in (hidden_in, hidden_out...)
    # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel, updated_kernel_amax = tex.cast_fp8(
        kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype
    )

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    output = fp8_dot_impl(
        ln_out,
        casted_kernel,
        x_scale_inv,
        kernel_scale_inv,
        x.dtype,
        (x_contracting_dims, k_contracting_dims),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
    )

    # Residuals for _layernorm_fp8_dot_bwd_rule. The updated amax values are
    # carried through so the backward pass can write them into amax_list.
    ctx = (
        ln_out,
        casted_kernel,
        amax_list,
        scale_list,
        scale_inv_list,
        updated_x_amax,
        updated_kernel_amax,
        x.shape,
        kernel.shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims,
        k_contracting_dims,
        maybe_fp32_to_fm32,
    )

    return output, ctx


279
def _layernorm_fp8_dot_bwd_rule(
    layernorm_type,
    fwd_dtype,  # pylint: disable=unused-argument
    bwd_dtype,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_input_axes,  # pylint: disable=unused-argument
    ctx,
    grad,
):
    """Backward rule for the fused layernorm + FP8 GEMM.

    Computes the weight gradient (FP8 GEMM of transposed ln_out x grad), the
    data gradient (FP8 GEMM of grad x kernel) and pushes it through the norm
    backward kernel. Also folds the updated amax values into amax_list so the
    differentiable FP8 metas returned to the caller stay current.
    """
    (
        ln_out_,
        casted_kernel,
        amax_list,
        scale_list,
        scale_inv_list,
        updated_x_amax,
        updated_kernel_amax,
        x_shape,
        kernel_shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims,
        k_contracting_dims,
        maybe_fp32_to_fm32,
    ) = ctx

    # Transposed layernorm output, needed to contract over the batch axes in
    # the wgrad GEMM below.
    ln_out_t = tex.transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # FP8 quantization metadata for the incoming gradient.
    grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
    grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
    grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]

    # Fused cast-and-transpose: casted_grad feeds dgrad, casted_grad_t feeds
    # wgrad, with a single amax reduction for both.
    casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
        grad,
        grad_amax,
        grad_scale,
        grad_scale_inv,
        bwd_dtype,
        static_axis_boundary=-1,
        transpose_axis_boundary=min(x_contracting_dims),
    )

    # After transposition, the former batch axes of x sit at the tail of
    # ln_out_t; contract them against the trailing batch axes of grad.
    xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape)))
    gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
    wgrad = fp8_dot_impl(
        ln_out_t,
        casted_grad_t,
        x_scale_inv,
        grad_scale_inv,
        grad.dtype,
        (xt_constracting_dim, gt_constracting_dim),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
    )

    # dgrad contracts grad's hidden_out axes against the kernel's hidden_out
    # axes, yielding the gradient w.r.t. the layernorm output.
    g_for_dgrad_constracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim)
    )
    k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
    dgrad = fp8_dot_impl(
        casted_grad,
        casted_kernel,
        grad_scale_inv,
        kernel_scale_inv,
        grad.dtype,
        (g_for_dgrad_constracting_dim, k_constracting_dim),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
    )

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)

    if layernorm_type == "layernorm":
        dx, dgamma, dbeta = tex.layernorm_bwd(
            dgrad,
            x,
            mu,
            rsigma,
            gamma,
            beta,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
        )
    else:
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
        dx, dgamma = tex.rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
        # RMSNorm has no shift parameter, hence no beta gradient.
        dbeta = None

    # Functionally write the freshly-observed amax values into slot 0 of each
    # meta's amax history (jax arrays are immutable, hence .at[...].set).
    amax_list[FP8MetaPackage.INPUT_IDX] = (
        amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
    )
    amax_list[FP8MetaPackage.WEIGHT_IDX] = (
        amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0])
    )
    amax_list[FP8MetaPackage.GRAD_IDX] = (
        amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
    )

    # Convert the metas back to the fm32 carrier dtype if they arrived that way.
    amax_list = maybe_fp32_to_fm32(*amax_list)
    scale_list = maybe_fp32_to_fm32(*scale_list)

    # Gradients for the six differentiable primal inputs:
    # (x, kernel, gamma, beta, amax_list, scale_list).
    return dx, wgrad, dgamma, dbeta, amax_list, scale_list
387
388


389
# Register the custom forward/backward rules for the fused layernorm + FP8 dot.
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)