# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX layernorm modules"""

from typing import Tuple, Sequence
from functools import partial, reduce
import operator
import jax
import jax.numpy as jnp

from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .cpp_extensions import transpose
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner

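# The sharded code paths below run through xmap_runner, which uses xmap's SPMD
# (manual) lowering; both experimental flags are enabled here at import time.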
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)


def canonicalize_layernorm_type(x):
    '''
    Canonicalize the layernorm type
    '''
    canonicalized = x.lower().strip().replace('-', '').replace('_', '')
    assert canonicalized in ['layernorm', 'rmsnorm']
    return canonicalized


def layernorm(inputs: jnp.ndarray,
              gamma: jnp.ndarray,
              beta: jnp.ndarray,
              layernorm_type: str,
              zero_centered_gamma: bool = False,
              epsilon: float = 1e-6,
              sharding_type: ShardingType = ShardingType.SINGLE,
              dp_dim_index: int = 0):
    """
    Layernorm wrapper
    """
    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
        "layernorm does not support row-split tensor parallelism currently."

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    if sharding_type is ShardingType.SINGLE:
        output = _layernorm(inputs,
                            gamma,
                            beta,
                            layernorm_type=layernorm_type,
                            zero_centered_gamma=zero_centered_gamma,
                            epsilon=epsilon,
                            sharding_type=sharding_type,
                            dp_axis_name="")
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"
        sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
                                                      dp_dim_index, dp_axis_name, tp_axis_name)
        inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0])    # 0 for input
        gamma_ = jnp.reshape(gamma, sharding_meta.input_shapes[1])    # 1 for gamma
        beta_ = beta
        beta_in_axis = {}
        if beta_ is not None:
            beta_ = jnp.reshape(beta_, sharding_meta.input_shapes[1])    # 1 for beta
            beta_in_axis = sharding_meta.in_axes[1]

        in_axes = (*sharding_meta.in_axes, beta_in_axis)

        partial_ln = partial(_layernorm,
                             layernorm_type=layernorm_type,
                             zero_centered_gamma=zero_centered_gamma,
                             epsilon=epsilon,
                             sharding_type=sharding_type,
                             dp_axis_name=dp_axis_name)

        output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes,
                             sharding_meta.axis_resources, (inputs_, gamma_, beta_))

        output = jnp.reshape(output, sharding_meta.output_shapes[0])

    return output
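# Illustrative usage of layernorm (a minimal sketch; batch, seqlen and hidden
# are hypothetical names, not symbols defined in this module):
#
#     x = jnp.zeros((batch, seqlen, hidden), jnp.bfloat16)
#     gamma = jnp.ones((hidden,), jnp.float32)
#     beta = jnp.zeros((hidden,), jnp.float32)
#     out = layernorm(x, gamma, beta, 'layernorm', zero_centered_gamma=False,
#                     epsilon=1e-6, sharding_type=ShardingType.SINGLE)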


@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type,
               dp_axis_name):
    output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon,
                               sharding_type, dp_axis_name)
    return output


def _layernorm_fwd(
        x,
        gamma,
        beta,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        sharding_type,    # pylint: disable=unused-argument
        dp_axis_name    # pylint: disable=unused-argument
):
    if layernorm_type == 'layernorm':
        output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
        mu = None
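    # Residuals for the backward pass: the statistics, the original input and gamma.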
    return output, (mu, rsigma, x, gamma)


def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name, ctx,
                   g):
    mu, rsigma, x, gamma = ctx

    if layernorm_type == 'layernorm':
        grad_input, grad_gamma, grad_beta = layernorm_bwd(g,
                                                          mu,
                                                          rsigma,
                                                          x,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        grad_input, grad_gamma = rmsnorm_bwd(g, rsigma, x, gamma, epsilon=epsilon)
        grad_beta = None

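    # gamma/beta are shared across data-parallel shards, so their gradients are
    # summed over the DP axis.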
    if is_dp_enabled(sharding_type.value[0]):
        grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
        if grad_beta is not None:
            grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
    return grad_input, grad_gamma, grad_beta


_layernorm.defvjp(_layernorm_fwd, _layernorm_bwd)


def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
                      gamma: jnp.ndarray,
                      beta: jnp.ndarray,
                      layernorm_type: str,
                      fwd_dtype: TEDType,
                      bwd_dtype: TEDType,
                      contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
                      zero_centered_gamma: bool = False,
                      epsilon: float = 1e-6,
                      sharding_type: ShardingType = ShardingType.SINGLE,
                      dp_dim_index: int = 0) -> jnp.ndarray:
    """
    LN + fp8 dot fusion wrapper
    """
    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
        "layernorm_fp8_dot does not support row-split tensor parallelism currently."

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    assert fp8_gemm_pkg.num_of_gemm == 1
    inputs = fp8_gemm_pkg.inputs
    kernel = fp8_gemm_pkg.kernels[0]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    if sharding_type is ShardingType.SINGLE:
        output = _layernorm_fp8_dot(inputs,
                                    kernel,
                                    gamma,
                                    beta,
                                    fp8_max,
                                    amax,
                                    scale,
                                    scale_inv,
                                    layernorm_type,
                                    fwd_dtype,
                                    bwd_dtype,
                                    contracting_dims,
                                    zero_centered_gamma=zero_centered_gamma,
                                    epsilon=epsilon,
                                    sharding_type=sharding_type,
                                    dp_axis_name="",
                                    tp_axis_name="")
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"

        ln_sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
                                                         dp_dim_index, dp_axis_name, tp_axis_name)
        inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0])    # 0 for input
        gamma_ = jnp.reshape(gamma, ln_sharding_meta.input_shapes[1])    # 1 for gamma
        beta_ = beta
        beta_in_axis = {}
        if beta_ is not None:
            beta_ = jnp.reshape(beta_, ln_sharding_meta.input_shapes[1])    # 1 for beta
            beta_in_axis = ln_sharding_meta.in_axes[1]

        kernel_tp_index = None
        # TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
        if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
            kernel_tp_index = len(kernel.shape) - 1
        elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
            kernel_tp_index = 0

        input_tp_index = len(inputs.shape) - 1
        dot_sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
                                                  dp_dim_index, input_tp_index, kernel_tp_index,
                                                  contracting_dims, dp_axis_name, tp_axis_name)
        kernel_ = jnp.reshape(kernel, dot_sharding_meta.input_shapes[1])    # 1 for kernel

        num_of_fp8_meta_kind = 4    # fp8_max, amax, scale, scale_inv
        fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
                                                       dp_axis_name, tp_axis_name)

        axis_resource = merge_axis_resources([
            ln_sharding_meta.axis_resources, dot_sharding_meta.axis_resources,
            fp8_sharding_meta.axis_resources
        ])

        partial_ln_fp8_dot = partial(_layernorm_fp8_dot,
                                     layernorm_type=layernorm_type,
                                     fwd_dtype=fwd_dtype,
                                     bwd_dtype=bwd_dtype,
                                     contracting_dims=contracting_dims,
                                     zero_centered_gamma=zero_centered_gamma,
                                     epsilon=epsilon,
                                     sharding_type=sharding_type,
                                     dp_axis_name=dp_axis_name,
                                     tp_axis_name=tp_axis_name)

        # input, kernel, gamma, beta, fp8_metas
        in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1],
                   ln_sharding_meta.in_axes[1], beta_in_axis, *fp8_sharding_meta.in_axes)

        output = xmap_runner(partial_ln_fp8_dot, in_axes, dot_sharding_meta.out_axes, axis_resource,
                             (inputs_, kernel_, gamma_, beta_, fp8_max, amax, scale, scale_inv))

        output = jnp.reshape(output, dot_sharding_meta.output_shapes[0])
    return output


@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray,
                       beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
                       scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
                       fwd_dtype: TEDType, bwd_dtype: TEDType,
                       contracting_dims: Tuple[Sequence[int], Sequence[int]],
                       zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType,
                       dp_axis_name: str, tp_axis_name: str) -> jnp.ndarray:
    output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
                                       scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
                                       contracting_dims, zero_centered_gamma, epsilon,
                                       sharding_type, dp_axis_name, tp_axis_name)
    return output


def _layernorm_fp8_dot_fwd(
        inputs,
        kernel,
        gamma,
        beta,
        fp8_maxs,
        amax,
        scale,
        scale_inv,
        layernorm_type,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        contracting_dims,
        zero_centered_gamma,
        epsilon,
        sharding_type,
        dp_axis_name,    # pylint: disable=unused-argument
        tp_axis_name):

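    # Flatten both operands to 2-D around the contraction: the trailing input
    # dims and the leading kernel dims are the contracting ones, so the fused
    # LN + dot can be expressed as a single GEMM.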
    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
    input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
    input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
    kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
    kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
    input_contracting_size = reduce(operator.mul, input_shape_suf)
    kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
    assert input_contracting_size == kernel_contracting_size

    amax = FP8Helper.update_amax_history(amax)

    gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    input_amax = amax[gemm_input_idx, 0:1]
    input_scale = scale[gemm_input_idx]
    input_scale_inv = scale_inv[gemm_input_idx]
    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, input_amax = layernorm_fwd_fp8(inputs,
                                                           gamma,
                                                           beta,
                                                           input_amax,
                                                           input_scale,
                                                           input_scale_inv,
                                                           zero_centered_gamma=zero_centered_gamma,
                                                           epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, input_amax = rmsnorm_fwd_fp8(inputs,
                                                     gamma,
                                                     input_amax,
                                                     input_scale,
                                                     input_scale_inv,
                                                     epsilon=epsilon)
        mu = None

    assert inputs.shape == ln_out.shape
    ln_out_ = jnp.reshape(ln_out, (-1, input_contracting_size))
    kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))

    kernel_amax = amax[gemm_kernel_idx, 0:1]
    kernel_scale = scale[gemm_kernel_idx]
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
                                                                 kernel_scale_inv, fwd_dtype)

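    # FP8 GEMM of the normalized (FP8-cast) activations against the cast
    # kernel; the output is produced in the original input dtype.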
    output = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, ln_out_, input_scale_inv,
                  fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)

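    # Row-split tensor parallelism shards the contracting dimension, so partial
    # GEMM results are summed across the TP axis.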
    if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
        output = jax.lax.psum(output, tp_axis_name)

    # (input_shape_pre, input_shape_suf)
    # x (kernel_shape_pre, kernel_shape_suf)
    # = (input_shape_pre, kernel_shape_suf)
    output_shape = input_shape_pre + kernel_shape_suf
    output = jnp.reshape(output, output_shape)

    ctx = (ln_out_, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
           inputs.shape, kernel.shape, mu, rsigma, inputs, gamma)
    return output, ctx


def _layernorm_fp8_dot_bwd(
        layernorm_type,
        fwd_dtype,
        bwd_dtype,
        contracting_dims,    # pylint: disable=unused-argument
        zero_centered_gamma,
        epsilon,
        sharding_type,
        dp_axis_name,
        tp_axis_name,
        ctx,
        g):
    ln_out_, kernel_cast, \
    fp8_maxs, amax, scale, scale_inv, \
    input_amax, kernel_amax, \
    inputs_shape, kernel_shape, \
    mu, rsigma, inputs, gamma = ctx

    gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = \
        FP8Helper.get_fp8_meta_indices(0)

    grad_amax = amax[gemm_grad_idx, 0:1]
    grad_scale = scale[gemm_grad_idx]
    grad_scale_inv = scale_inv[gemm_grad_idx]

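    # Transpose the saved LN output and cast/transpose the incoming gradient to
    # the backward FP8 dtype so that wgrad and dgrad below run as FP8 GEMMs.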
    ln_out_trans = transpose(ln_out_, fwd_dtype)
    g = jnp.reshape(g, (ln_out_trans.shape[1], -1))

    # cast and transpose the grad_output
    grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
                                                           bwd_dtype)

    input_scale_inv = scale_inv[gemm_input_idx]
    wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, True, ln_out_trans, input_scale_inv,
                 fwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)

    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
                 bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)

    dgrad = jnp.reshape(dgrad, inputs_shape)

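    # Under column-split TP the kernel's output dimension is sharded; dgrad
    # contracts over that dimension, so the partial results are summed across
    # the TP axis.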
    if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
        dgrad = jax.lax.psum(dgrad, tp_axis_name)

    if layernorm_type == 'layernorm':
        grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad,
                                                          mu,
                                                          rsigma,
                                                          inputs,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
        grad_beta = None

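    # Write the newly measured amax values of the input, kernel and gradient
    # tensors back into slot 0 of the amax history.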
    amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
    amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
    amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])

    if is_dp_enabled(sharding_type.value[0]):
        wgrad = jax.lax.psum(wgrad, dp_axis_name)
        grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
        if grad_beta is not None:
            grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
        amax = jax.lax.pmax(amax, dp_axis_name)

    if is_tp_enabled(sharding_type.value[0]):
        amax = jax.lax.pmax(amax, tp_axis_name)

    wgrad = jnp.reshape(wgrad, kernel_shape)
    return grad_input, wgrad, \
           grad_gamma, grad_beta, \
           fp8_maxs, amax, scale, scale_inv


_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd, _layernorm_fp8_dot_bwd)