# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX layernorm modules"""

from typing import Tuple, Sequence
from functools import partial, reduce
import operator
import jax
import jax.numpy as jnp

from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .cpp_extensions import transpose
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
21
from .sharding import xmap_runner, extend_fsdp_sharding_meta
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

# Turn on both experimental xmap SPMD lowering flags. This runs at import time,
# so importing this module changes the process-wide JAX configuration.
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)


def canonicalize_layernorm_type(x):
    '''
    Return the canonical spelling of a layernorm type name.

    The name is lower-cased, stripped of surrounding whitespace, and has every
    '-' and '_' removed. The result must be 'layernorm' or 'rmsnorm'; anything
    else trips the assertion.
    '''
    trimmed = x.lower().strip()
    normalized = ''.join(ch for ch in trimmed if ch not in ('-', '_'))
    assert normalized in ('layernorm', 'rmsnorm')
    return normalized


def layernorm(inputs: jnp.ndarray,
              gamma: jnp.ndarray,
              beta: jnp.ndarray,
              layernorm_type: str,
              zero_centered_gamma: bool = False,
              epsilon: float = 1e-6,
              sharding_type: ShardingType = ShardingType.SINGLE,
              dp_dim_index: int = 0):
    """
    Layernorm wrapper

    Validates the arguments, then runs `_layernorm` either directly (when
    `sharding_type` is `ShardingType.SINGLE`) or through `xmap_runner` with the
    operands reshaped according to the element-wise sharding meta.

    Args:
        inputs: tensor to normalize.
        gamma: scale parameter.
        beta: shift parameter; must be None when `layernorm_type` is 'rmsnorm'.
        layernorm_type: 'layernorm' or 'rmsnorm' (case, '-' and '_' are ignored,
            see `canonicalize_layernorm_type`).
        zero_centered_gamma: forwarded to the layernorm kernels; must be False
            for 'rmsnorm'.
        epsilon: forwarded to the normalization kernels.
        sharding_type: sharding strategy; the row-split TP variants are rejected.
        dp_dim_index: index of the data-parallel dimension of `inputs`, passed to
            the sharding-meta helpers.

    Returns:
        The normalized tensor, reshaped to `sharding_meta.output_shapes[0]` in the
        sharded case.
    """
    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
        "layernorm does not support row-split tensor parallelism currently."

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    if sharding_type is ShardingType.SINGLE:
        # No mesh axes are involved, so the axis names are passed as empty strings.
        output = _layernorm(inputs,
                            gamma,
                            beta,
                            layernorm_type=layernorm_type,
                            zero_centered_gamma=zero_centered_gamma,
                            epsilon=epsilon,
                            sharding_type=sharding_type,
                            dp_axis_name="",
                            fsdp_axis_name="")
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"
        sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
                                                      dp_dim_index, dp_axis_name, tp_axis_name)

        sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
        inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0])    # 0 for input
        gamma_ = jnp.reshape(gamma, sharding_meta.input_shapes[1])    # 1 for gamma
        beta_ = beta
        # beta is optional: when it is None it is passed through with an empty axis mapping,
        # otherwise it reuses gamma's shape and in_axes entry.
        beta_in_axis = {}
        if beta_ is not None:
            beta_ = jnp.reshape(beta_, sharding_meta.input_shapes[1])    # 1 for beta
            beta_in_axis = sharding_meta.in_axes[1]

        in_axes = (*sharding_meta.in_axes, beta_in_axis)

        partial_ln = partial(_layernorm,
                             layernorm_type=layernorm_type,
                             zero_centered_gamma=zero_centered_gamma,
                             epsilon=epsilon,
                             sharding_type=sharding_type,
                             dp_axis_name=dp_axis_name,
                             fsdp_axis_name=fsdp_axis_name)

        output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes,
                             sharding_meta.axis_resources, (inputs_, gamma_, beta_))

        output = jnp.reshape(output, sharding_meta.output_shapes[0])

    return output


99
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type,
               dp_axis_name, fsdp_axis_name):
    """
    Layernorm / RMSNorm with a custom VJP.

    `x`, `gamma` and `beta` are the differentiable operands; the remaining
    arguments (positions 3-8) are declared non-differentiable. The primal
    computation is the output of `_layernorm_fwd` with its residuals dropped.
    """
    output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon,
                               sharding_type, dp_axis_name, fsdp_axis_name)
    return output


def _layernorm_fwd(
        x,
        gamma,
        beta,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        sharding_type,    # pylint: disable=unused-argument
        dp_axis_name,    # pylint: disable=unused-argument
        fsdp_axis_name    # pylint: disable=unused-argument
):
    """
    Forward rule of `_layernorm`.

    Dispatches to `layernorm_fwd` when `layernorm_type` is 'layernorm' and to
    `rmsnorm_fwd` otherwise. The sharding arguments are accepted only to match
    the custom-VJP signature and are not used here.

    Returns:
        `(output, (mu, rsigma, x, gamma))`, where the second element is the
        residual consumed by `_layernorm_bwd`. `mu` is None on the rmsnorm path.
    """
    if layernorm_type == 'layernorm':
        output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
        mu = None
    return output, (mu, rsigma, x, gamma)


128
129
def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name,
                   fsdp_axis_name, ctx, g):
    """
    Backward rule of `_layernorm`.

    Args:
        layernorm_type, zero_centered_gamma, epsilon, sharding_type,
        dp_axis_name, fsdp_axis_name: the non-differentiable arguments of
            `_layernorm`, in declaration order.
        ctx: residual `(mu, rsigma, x, gamma)` produced by `_layernorm_fwd`.
        g: cotangent of the output.

    Returns:
        `(grad_input, grad_gamma, grad_beta)`; `grad_beta` is None on the
        rmsnorm path.
    """
    mu, rsigma, x, gamma = ctx

    if layernorm_type == 'layernorm':
        grad_input, grad_gamma, grad_beta = layernorm_bwd(g,
                                                          mu,
                                                          rsigma,
                                                          x,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        grad_input, grad_gamma = rmsnorm_bwd(g, rsigma, x, gamma, epsilon=epsilon)
        grad_beta = None

    # Parameter gradients are summed across the data-parallel axis when DP is enabled ...
    if is_dp_enabled(sharding_type.value[0]):
        grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
        if grad_beta is not None:
            grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
    # ... and across the FSDP axis when one was supplied (non-empty name).
    if len(fsdp_axis_name) > 0:
        grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
        if grad_beta is not None:
            grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)

    return grad_input, grad_gamma, grad_beta


# Register _layernorm_fwd / _layernorm_bwd as the forward and backward rules of _layernorm.
_layernorm.defvjp(_layernorm_fwd, _layernorm_bwd)


def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
                      gamma: jnp.ndarray,
                      beta: jnp.ndarray,
                      layernorm_type: str,
                      fwd_dtype: TEDType,
                      bwd_dtype: TEDType,
                      contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
                      zero_centered_gamma: bool = False,
                      epsilon: float = 1e-6,
                      sharding_type: ShardingType = ShardingType.SINGLE,
                      dp_dim_index: int = 0) -> jnp.ndarray:
    """
    LN + fp8 dot fusion wrapper

    Unpacks the single-GEMM `fp8_gemm_pkg`, then runs `_layernorm_fp8_dot` either
    directly (when `sharding_type` is `ShardingType.SINGLE`) or through
    `xmap_runner` with operands reshaped according to the layernorm and dot
    sharding metas.

    Args:
        fp8_gemm_pkg: package holding the inputs, exactly one kernel, and the FP8
            meta tensors (`fp8_max`, `amax`, `scale`, `scale_inv`).
        gamma: layernorm scale parameter.
        beta: layernorm shift parameter; must be None for 'rmsnorm'.
        layernorm_type: 'layernorm' or 'rmsnorm' (case, '-' and '_' are ignored).
        fwd_dtype: TE dtype used for the forward casts and GEMM operands.
        bwd_dtype: TE dtype used for the gradient cast in the backward pass.
        contracting_dims: `(lhs_dims, rhs_dims)` contracted by the dot.
        zero_centered_gamma: forwarded to the layernorm kernels; must be False
            for 'rmsnorm'.
        epsilon: forwarded to the normalization kernels.
        sharding_type: sharding strategy; the row-split TP variants are rejected.
        dp_dim_index: index of the data-parallel dimension of the inputs, passed
            to the sharding-meta helpers.

    Returns:
        The dot output, reshaped to `dot_sharding_meta.output_shapes[0]` in the
        sharded case.
    """
    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
        "layernorm_fp8_dot does not support row-split tensor parallelism currently."

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    assert fp8_gemm_pkg.num_of_gemm == 1
    inputs = fp8_gemm_pkg.inputs
    kernel = fp8_gemm_pkg.kernels[0]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    if sharding_type is ShardingType.SINGLE:
        # No mesh axes are involved, so the axis names are passed as empty strings.
        output = _layernorm_fp8_dot(inputs,
                                    kernel,
                                    gamma,
                                    beta,
                                    fp8_max,
                                    amax,
                                    scale,
                                    scale_inv,
                                    layernorm_type,
                                    fwd_dtype,
                                    bwd_dtype,
                                    contracting_dims,
                                    zero_centered_gamma=zero_centered_gamma,
                                    epsilon=epsilon,
                                    sharding_type=sharding_type,
                                    dp_axis_name="",
                                    tp_axis_name="",
                                    fsdp_axis_name="")
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"

        ln_sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
                                                         dp_dim_index, dp_axis_name, tp_axis_name)
        # Only the extended meta is needed here; the FSDP axis name is taken from the
        # dot sharding meta below.
        ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index})
        inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0])    # 0 for input
        gamma_ = jnp.reshape(gamma, ln_sharding_meta.input_shapes[1])    # 1 for gamma
        beta_ = beta
        # beta is optional: when it is None it is passed through with an empty axis mapping,
        # otherwise it reuses gamma's shape and in_axes entry.
        beta_in_axis = {}
        if beta_ is not None:
            beta_ = jnp.reshape(beta_, ln_sharding_meta.input_shapes[1])    # 1 for beta
            beta_in_axis = ln_sharding_meta.in_axes[1]

        kernel_tp_index = None
        # TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
        if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
            kernel_tp_index = len(kernel.shape) - 1
        elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
            # NOTE: the assertion at the top of this function rejects the row-split types,
            # so this branch is only reachable when assertions are disabled.
            kernel_tp_index = 0

        input_tp_index = len(inputs.shape) - 1
        dot_sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
                                                  dp_dim_index, input_tp_index, kernel_tp_index,
                                                  contracting_dims, dp_axis_name, tp_axis_name)
        dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(dot_sharding_meta,
                                                                      {0: dp_dim_index})
        kernel_ = jnp.reshape(kernel, dot_sharding_meta.input_shapes[1])    # 1 for kernel

        num_of_fp8_meta_kind = 4    # fp8_max, amax, scale, scale_inv
        fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
                                                       dp_axis_name, tp_axis_name)

        axis_resource = merge_axis_resources([
            ln_sharding_meta.axis_resources, dot_sharding_meta.axis_resources,
            fp8_sharding_meta.axis_resources
        ])

        partial_ln_fp8_dot = partial(_layernorm_fp8_dot,
                                     layernorm_type=layernorm_type,
                                     fwd_dtype=fwd_dtype,
                                     bwd_dtype=bwd_dtype,
                                     contracting_dims=contracting_dims,
                                     zero_centered_gamma=zero_centered_gamma,
                                     epsilon=epsilon,
                                     sharding_type=sharding_type,
                                     dp_axis_name=dp_axis_name,
                                     tp_axis_name=tp_axis_name,
                                     fsdp_axis_name=fsdp_axis_name)

        # input, kernel, gamma, beta, fp8_metas
        in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1],
                   ln_sharding_meta.in_axes[1], beta_in_axis, *fp8_sharding_meta.in_axes)

        output = xmap_runner(partial_ln_fp8_dot, in_axes, dot_sharding_meta.out_axes, axis_resource,
                             (inputs_, kernel_, gamma_, beta_, fp8_max, amax, scale, scale_inv))

        output = jnp.reshape(output, dot_sharding_meta.output_shapes[0])
    return output


273
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray,
                       beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
                       scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
                       fwd_dtype: TEDType, bwd_dtype: TEDType,
                       contracting_dims: Tuple[Sequence[int], Sequence[int]],
                       zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType,
                       dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str) -> jnp.ndarray:
    """
    Fused layernorm + FP8 dot with a custom VJP.

    The first eight arguments (`inputs` through `scale_inv`) are the
    differentiable operands; the remaining ones (positions 8-17) are declared
    non-differentiable. The primal computation is the output of
    `_layernorm_fp8_dot_fwd` with its residuals dropped.
    """
    output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
                                       scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
                                       contracting_dims, zero_centered_gamma, epsilon,
                                       sharding_type, dp_axis_name, tp_axis_name, fsdp_axis_name)
    return output


def _layernorm_fp8_dot_fwd(
        inputs,
        kernel,
        gamma,
        beta,
        fp8_maxs,
        amax,
        scale,
        scale_inv,
        layernorm_type,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        contracting_dims,
        zero_centered_gamma,
        epsilon,
        sharding_type,
        dp_axis_name,    # pylint: disable=unused-argument
        tp_axis_name,
        fsdp_axis_name):    # pylint: disable=unused-argument
    """
    Forward rule of `_layernorm_fp8_dot`.

    Runs the FP8 layernorm (or rmsnorm) kernel on `inputs`, casts and transposes
    the kernel to `fwd_dtype`, and multiplies the two with `gemm`. `bwd_dtype`,
    `dp_axis_name` and `fsdp_axis_name` are accepted only to match the custom-VJP
    signature and are not used here.

    Returns:
        `(output, ctx)`, where `output` has shape
        `input_shape_pre + kernel_shape_suf` and `ctx` is the residual tuple
        consumed by `_layernorm_fp8_dot_bwd`.
    """

    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
    # Split both shapes at the contracting boundary; the contracted extents must match.
    input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
    input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
    kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
    kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
    input_contracting_size = reduce(operator.mul, input_shape_suf)
    kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
    assert input_contracting_size == kernel_contracting_size

    amax = FP8Helper.update_amax_history(amax)

    gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    # Column 0 of each amax row is handed to the kernel, which returns an updated value;
    # the backward rule writes these values back into `amax`.
    input_amax = amax[gemm_input_idx, 0:1]
    input_scale = scale[gemm_input_idx]
    input_scale_inv = scale_inv[gemm_input_idx]
    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, input_amax = layernorm_fwd_fp8(inputs,
                                                           gamma,
                                                           beta,
                                                           input_amax,
                                                           input_scale,
                                                           input_scale_inv,
                                                           zero_centered_gamma=zero_centered_gamma,
                                                           epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, input_amax = rmsnorm_fwd_fp8(inputs,
                                                     gamma,
                                                     input_amax,
                                                     input_scale,
                                                     input_scale_inv,
                                                     epsilon=epsilon)
        mu = None

    assert inputs.shape == ln_out.shape
    # Flatten both operands to 2D around the contracting dimension for the GEMM.
    ln_out_ = jnp.reshape(ln_out, (-1, input_contracting_size))
    kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))

    kernel_amax = amax[gemm_kernel_idx, 0:1]
    kernel_scale = scale[gemm_kernel_idx]
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
                                                                 kernel_scale_inv, fwd_dtype)

    output = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, ln_out_, input_scale_inv,
                  fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)

    if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
        output = jax.lax.psum(output, tp_axis_name)

    # (input_shape_pre, input_shape_suf)
    # x (kernel_shape_pre, kernel_shape_suf)
    # = (input_shape_pre, kernel_shape_suf)
    output_shape = input_shape_pre + kernel_shape_suf
    output = jnp.reshape(output, output_shape)

    ctx = (ln_out_, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
           inputs.shape, kernel.shape, mu, rsigma, inputs, gamma)
    return output, ctx


def _layernorm_fp8_dot_bwd(
        layernorm_type,
        fwd_dtype,
        bwd_dtype,
        contracting_dims,    # pylint: disable=unused-argument
        zero_centered_gamma,
        epsilon,
        sharding_type,
        dp_axis_name,
        tp_axis_name,
        fsdp_axis_name,
        ctx,
        g):
    """
    Backward rule of `_layernorm_fp8_dot`.

    Args:
        layernorm_type ... fsdp_axis_name: the non-differentiable arguments of
            `_layernorm_fp8_dot`, in declaration order (`contracting_dims` is
            unused here).
        ctx: residual tuple produced by `_layernorm_fp8_dot_fwd`.
        g: cotangent of the dot output.

    Returns:
        An 8-tuple matching the differentiable operands
        `(inputs, kernel, gamma, beta, fp8_maxs, amax, scale, scale_inv)`:
        `(grad_input, wgrad, grad_gamma, grad_beta, fp8_maxs, amax, scale, scale_inv)`.
        The FP8 meta slots carry `fp8_maxs`, `scale` and `scale_inv` unchanged and
        `amax` with column 0 of the input/kernel/grad rows overwritten (and
        max-reduced over the enabled mesh axes), not true gradients.
    """
    ln_out_, kernel_cast, \
    fp8_maxs, amax, scale, scale_inv, \
    input_amax, kernel_amax, \
    inputs_shape, kernel_shape, \
    mu, rsigma, inputs, gamma = ctx

    gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = \
        FP8Helper.get_fp8_meta_indices(0)

    grad_amax = amax[gemm_grad_idx, 0:1]
    grad_scale = scale[gemm_grad_idx]
    grad_scale_inv = scale_inv[gemm_grad_idx]

    ln_out_trans = transpose(ln_out_, fwd_dtype)
    # Flatten the cotangent to 2D so its leading dimension matches the rows of ln_out_.
    g = jnp.reshape(g, (ln_out_trans.shape[1], -1))

    # cast and transpose the grad_output
    grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
                                                           bwd_dtype)

    input_scale_inv = scale_inv[gemm_input_idx]
    wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, True, ln_out_trans, input_scale_inv,
                 fwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)

    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
                 bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)

    dgrad = jnp.reshape(dgrad, inputs_shape)

    if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
        dgrad = jax.lax.psum(dgrad, tp_axis_name)

    if layernorm_type == 'layernorm':
        grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad,
                                                          mu,
                                                          rsigma,
                                                          inputs,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
        grad_beta = None

    # Write the amax values returned by the FP8 kernels back into column 0 of their rows.
    amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
    amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
    amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])

    # Parameter gradients are summed, and amax max-reduced, over each enabled mesh axis.
    if is_dp_enabled(sharding_type.value[0]):
        wgrad = jax.lax.psum(wgrad, dp_axis_name)
        grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
        if grad_beta is not None:
            grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
        amax = jax.lax.pmax(amax, dp_axis_name)

    if len(fsdp_axis_name) > 0:
        wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
        grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
        if grad_beta is not None:
            grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
        amax = jax.lax.pmax(amax, fsdp_axis_name)

    if is_tp_enabled(sharding_type.value[0]):
        amax = jax.lax.pmax(amax, tp_axis_name)

    wgrad = jnp.reshape(wgrad, kernel_shape)
    return grad_input, wgrad, \
           grad_gamma, grad_beta, \
           fp8_maxs, amax, scale, scale_inv


# Register _layernorm_fp8_dot_fwd / _layernorm_fp8_dot_bwd as the forward and backward
# rules of _layernorm_fp8_dot.
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd, _layernorm_fp8_dot_bwd)