# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""

6
7
from typing import List
from functools import partial
8
9
10
11

import jax
import jax.numpy as jnp

12
from .cpp_extensions import cast_fp8, transpose, cast_transpose
13
14
15
16
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
17
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
18
from .layernorm import canonicalize_layernorm_type
19
from .fp8 import FP8Helper, FP8MetaPackage
20
21


22
def geglu(x: jnp.ndarray):
    """
    Gated gelu
    """
    # The axis before last must carry exactly the two projections
    # (Linear branch + GeLU branch) that the gate combines.
    assert x.shape[-2] == 2

    return _geglu(x)


33
34
@jax.custom_vjp
def _geglu(x: jnp.ndarray):
    """Primal evaluation of gated GeLU under a custom VJP.

    `partial(jax.custom_vjp)` with no extra arguments was a no-op wrapper
    around the decorator, so it is applied directly here.

    Args:
        x: input whose second-to-last axis holds the (Linear, GeLU) pair.

    Returns:
        The gated-GeLU output (residuals from the fwd rule are discarded).
    """
    geglu_output, _ = _geglu_fwd_rule(x)
    return geglu_output


41
42
43
def _geglu_fwd_rule(x):
    """Forward rule: evaluate gated GeLU and keep the input as the residual."""
    return gated_gelu(x), (x,)
44
45


46
47
48
def _geglu_bwd_rule(ctx, g):
    """Backward rule: push the cotangent through the gated-GeLU primitive."""
    (x,) = ctx
    assert x.dtype == g.dtype

    # dgated_gelu may flatten the activation axis; restore the input's shape.
    grad_wrt_x = jnp.reshape(dgated_gelu(g, x), x.shape)
    return (grad_wrt_x,)


55
# Wire the forward/backward rules into the custom-VJP function.
_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
56
57


58
59
60
61
62
63
64
65
def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            fp8_gemm_pkg: FP8MetaPackage,
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6) -> jnp.ndarray:
    """
    Layernorm + GEMM1 + GeGLU + GEMM2
    """
    # Exactly two GEMMs in this fused MLP, and the FP8 meta package
    # must carry metadata for both of them.
    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    kernel_1, kernel_2 = kernels

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    return _layernrom_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2,
                                    fp8_gemm_pkg.fp8_max, fp8_gemm_pkg.amax,
                                    fp8_gemm_pkg.scale, fp8_gemm_pkg.scale_inv,
                                    FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE,
                                    layernorm_type, zero_centered_gamma, epsilon)


@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13))
def _layernrom_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
                             amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                             fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
                             zero_centered_gamma: bool, epsilon: float):
    """Differentiable core of the fused LayerNorm + GeGLU FP8 MLP.

    Runs the forward rule and drops its residuals; the dtypes, norm type,
    and epsilon (argnums 9-13) are non-differentiable static arguments.
    """
    primal_out, _ = _layernrom_geglu_fp8_mlp_fwd_rule(
        x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale, scale_inv,
        fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon)
    return primal_out


def _layernrom_geglu_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon):
    """Forward rule for the fused LayerNorm + GEMM1 + GeGLU + GEMM2 FP8 MLP.

    Casts activations and kernels to FP8 (tracking updated amax values for
    each tensor) and returns the second GEMM's output together with the
    residual tuple consumed by `_layernrom_geglu_fp8_mlp_bwd_rule`.
    """
    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out)
    # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == 2
    assert len(kernel_2.shape) == 2

    # Contract over the trailing hidden axis; for transposed operands the
    # batch axes shift to positions 1..ndim-1.
    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    # Roll the amax history before reading this step's slots.
    amax = FP8Helper.update_amax_history(amax)

    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    # FP8 metadata for the GEMM1 input (the normalized activations).
    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    if layernorm_type == 'layernorm':
        # Fused layernorm + FP8 cast; mu/rsigma are saved for the backward pass.
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        # RMSNorm has no mean statistic.
        mu = None

    assert x.shape == ln_out.shape

    # FP8 metadata for the GEMM1 kernel.
    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    # Note (Ming Huang): Use cast only to allow XLA handle transpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    # (batch..., hidden_in) x (hidden_in, 2, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

    # FP8 metadata for the GEMM2 input (the GeGLU output).
    geglu_out_amax = amax[gemm2_x_idx, 0:1]
    geglu_out_scale = scale[gemm2_x_idx]
    geglu_out_scale_inv = scale_inv[gemm2_x_idx]

    # (batch..., hidden_in) -> (batch..., hidden)
    casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax,
                                                          geglu_out_scale, geglu_out_scale_inv,
                                                          fwd_dtype)

    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use native cast to allow XLA handle transpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_in) x (hidden_out, hidden_in)
    dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv,
                                kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    # Residuals for the backward rule: tuple layout must match the
    # destructuring in _layernrom_geglu_fp8_mlp_bwd_rule exactly.
    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
           updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims)

    return dot_2_output, ctx


def _layernrom_geglu_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        ctx,
        grad):
    """Backward rule for the fused LayerNorm + GEMM1 + GeGLU + GEMM2 FP8 MLP.

    Computes input/gamma/beta gradients and both weight gradients in FP8,
    then writes the amax values collected in the forward pass (plus the
    gradient amaxes produced here) back into the FP8 meta tensors and
    refreshes the scales.
    """
    # Must mirror the ctx tuple built at the end of the forward rule.
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
    casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
    updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims = ctx

    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    # FP8 metadata for the incoming gradient (GEMM2's grad slot).
    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    # Cast the cotangent to FP8 and get its transpose in one fused op;
    # both layouts are needed (wgrad uses the transpose, dgrad the plain one).
    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=-1)

    casted_geglu_out_t = transpose(casted_geglu_out,
                                   static_axis_boundary=-1,
                                   transpose_axis_boundary=-1)

    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
    wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
                           grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    # FP8 metadata for the GeGLU gradient (GEMM1's grad slot).
    dgeglu_amax = amax[gemm1_grad_idx, 0:1]
    dgeglu_scale = scale[gemm1_grad_idx]
    dgeglu_scale_inv = scale_inv[gemm1_grad_idx]

    # Fused dGeGLU + FP8 cast + transpose, again producing both layouts.
    casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose(
        dgrad_2,
        dot_1_output,
        dgeglu_amax,
        dgeglu_scale,
        dgeglu_scale_inv,
        bwd_dtype,
        static_axis_boundary=-1)

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (2, hidden, batch...)
    # Shift batch dims past the leading activation axis of the transposed dGeGLU.
    xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv,
                           grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out)
    # Contract over both the activation-pair axis and the shifted hidden axis.
    x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple(
        i + 1 for i in x_contracting_dims)
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
    dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv,
                           grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        # RMSNorm has no beta, hence no beta gradient.
        dbeta = None

    # Fold every freshly observed amax back into row 0 of the history.
    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0])
    # NOTE(review): updated_kernel_2_amax comes from quantize() and is used
    # without the [0] index the other amaxes take — presumably it is already
    # a scalar; confirm against quantize()'s return shape.
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

    # Recompute scales/scale_invs from the refreshed amax history.
    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    # Gradient order matches the differentiable arguments of
    # _layernrom_geglu_fp8_mlp (argnums 0-8).
    return dx, dgamma, dbeta, wgrad_1, wgrad_2, \
           fp8_max, amax, scale, scale_inv


# Wire the forward/backward rules into the custom-VJP function.
_layernrom_geglu_fp8_mlp.defvjp(_layernrom_geglu_fp8_mlp_fwd_rule,
                                _layernrom_geglu_fp8_mlp_bwd_rule)