# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes


def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """
    Activation Unit
    """
    if len(activation_type) > 1:
        assert x.shape[-2] == 2    # Linear + GeLU
    output = _activation_lu(x, activation_type)
    return output
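
# Usage sketch (illustrative only; the exact input layout follows act_lu's convention):
# for a gated activation pair such as ('gelu', 'linear'), the two branches are stacked
# along axis -2, i.e. x.shape[-2] == 2 (see the assert above), e.g.
#   out = activation_lu(x, activation_type=('gelu', 'linear'))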


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):

    _output, _ = _activation_lu_fwd_rule(x, activation_type)

    return _output


def _activation_lu_fwd_rule(x, activation_type):
    fwd_output = act_lu(x, activation_type)
    return fwd_output, (x,)


def _activation_lu_bwd_rule(activation_type, ctx, g):
    x, = ctx
    assert x.dtype == g.dtype

    dx = dact_lu(g, x, activation_type)
    dx = jnp.reshape(dx, x.shape)
    return (dx,)


_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)


def fused_layernorm_fp8_mlp(x: jnp.ndarray,
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            biases: List[jnp.ndarray],
                            fp8_gemm_pkg: FP8MetaPackage,
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6,
                            layernorm_input_axes: Tuple[str, ...] = None,
                            dot_1_input_axes: Tuple[str, ...] = None,
                            dot_2_input_axes: Tuple[str, ...] = None,
                            ffn1_ckpt_name: str = 'ffn1',
                            ffn2_ckpt_name: str = 'ffn2',
                            activation_type: Sequence[Union[str, Callable]] = ('gelu',),
                            use_bias: bool = True) -> jnp.ndarray:
    """
    Layernorm + GEMM1 + bias + activation + GEMM2 + bias
    """

    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    bias_1 = biases[0]
    bias_2 = biases[1]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
                                      amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                      zero_centered_gamma, epsilon, layernorm_input_axes,
                                      dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
                                      ffn2_ckpt_name, activation_type, use_bias)
    return output
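
# Usage sketch (illustrative; shapes and the FP8MetaPackage construction are assumptions
# rather than a prescribed API -- in practice this is invoked by the higher-level Flax
# module with FP8 metadata managed by FP8Helper):
#   out = fused_layernorm_fp8_mlp(
#       x,                            # (batch..., hidden_in)
#       gamma, beta,                  # (hidden_in,); beta is None for 'rmsnorm'
#       kernels=[kernel_1, kernel_2], # (hidden_in, len(act), ffn_hidden), (ffn_hidden, hidden_out)
#       biases=[bias_1, bias_2],
#       fp8_gemm_pkg=fp8_meta_pkg,    # FP8MetaPackage covering both GEMMs
#       layernorm_type='layernorm',
#       activation_type=('gelu',))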


@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
                             bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
                             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
                             bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                             epsilon: float, layernorm_input_axes: Tuple[str, ...],
                             dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                             ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                             activation_type: Sequence[Union[str, Callable]], use_bias: bool):
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
        x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv,
        fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes,
        dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type,
        use_bias)
    return output


def _fused_layernorm_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name,
        activation_type,
        use_bias):

    # x is expected to have shape (batch..., hidden_in)
    # kernel_1 is expected to have shape (hidden_in, len(activation_type), hidden_out)
    # kernel_2 is expected to have shape (hidden_in, hidden_out), where kernel_1's
    # output width matches kernel_2's input width (asserted below)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == len(activation_type)
    assert len(kernel_2.shape) == 2

    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

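    # The FP8 metadata (fp8_max / amax history / scale / scale_inv) may arrive in the
    # helper's meta dtype; convert it to fp32 for arithmetic here and convert it back
    # in the backward pass via maybe_fp32_to_fm32. The scaling factors are then
    # refreshed from the amax history before the FP8 casts below.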
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
        FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
    fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    amax = FP8Helper.update_amax_history(amax)

    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    # Note (Ming Huang): Cast only (no fused transpose) so that XLA handles the
    # transpose itself, avoiding an extra copy that would break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    if use_bias:
        bias_1_shape = bias_1.shape
        bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
        dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
    else:
        bias_1_shape = None
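
    # checkpoint_name labels this activation so jax.checkpoint/remat policies
    # (e.g. jax.checkpoint_policies.save_only_these_names) can choose whether the
    # FFN1 output is saved or recomputed in the backward pass.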
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

    activation_lu_out_amax = amax[gemm2_x_idx, 0:1]
    activation_lu_out_scale = scale[gemm2_x_idx]
    activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]


    # (batch..., len(activation_type), hidden_out) -> (batch..., hidden_out)
    casted_activation_lu_out, updated_activation_lu_amax = \
    act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
               activation_lu_out_scale_inv, fwd_dtype, activation_type)

    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
        casted_activation_lu_out, dot_2_input_axes)

    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use a native cast so that XLA handles the transpose itself,
    # avoiding an extra copy that would break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
                                activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    if use_bias:
        bias_2_shape = bias_2.shape
        bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
        dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
    else:
        bias_2_shape = None

    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
           x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32)

    return dot_2_output, ctx


def _fused_layernorm_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
        activation_type,
        use_bias,
        ctx,
        grad):
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
    casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
    updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx

    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    # The sharding of the output should match dot_1's input, so constrain the incoming gradient accordingly
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
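    # Cast the incoming gradient to FP8 once, producing both a regular and a
    # transposed copy (consumed by dgrad_2 and wgrad_2 respectively); with use_bias
    # the bias gradient is fused into the same cast/transpose kernel.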
    if use_bias:
        casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
        dbias_cast_transpose(grad, grad_amax, grad_scale,
                             grad_scale_inv, bwd_dtype,
                             static_axis_boundary=-1,
                             transpose_axis_boundary=-1)
        dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
    else:
        casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale,
                       grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1,
                       transpose_axis_boundary=-1)
        dbias_2 = None

    casted_activation_lu_out_t = transpose(casted_activation_lu_out,
                                           static_axis_boundary=-1,
                                           transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
    wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
                           grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    dactivation_lu_amax = amax[gemm1_grad_idx, 0:1]
    dactivation_lu_scale = scale[gemm1_grad_idx]
    dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]

    if len(activation_type) > 1:    # if gated
        if use_bias:
            dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
            dbias_cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            dgated_act_lu_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                activation_type=activation_type)
            dbias_1 = None
    else:
        if use_bias:
            casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
            dact_lu_dbias_cast_transpose(
                dgrad_2,
                dot_1_output,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2,
                activation_type=activation_type)
            dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
        else:
            dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
            casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
            cast_transpose(
                dactivation_lu,
                dactivation_lu_amax,
                dactivation_lu_scale,
                dactivation_lu_scale_inv,
                bwd_dtype,
                static_axis_boundary=-1,
                transpose_axis_boundary=-2)
            dbias_1 = None

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
    xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
                           dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

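    # dgrad_1 contracts the last two axes of casted_dactivation_lu
    # (batch..., len(activation_type), hidden_out) with axes (1, 2) of
    # casted_kernel_1 (hidden_in, len(activation_type), hidden_out).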
    x_contracting_dims = ((min(x_contracting_dims),) + tuple(
            i + 1 for i in x_contracting_dims), (1,2))
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
                           kernel_1_scale_inv, grad.dtype, x_contracting_dims,
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None
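
    # Write the freshly observed amax values into slot 0 of the amax history and
    # return the (possibly re-wrapped) FP8 metadata through the VJP outputs so the
    # caller can update its FP8 state.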

    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0])
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

    fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)

    return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
           fp8_max, amax, scale, scale_inv


_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
                                _fused_layernorm_fp8_mlp_bwd_rule)