test_custom_call_compute.py 26.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import functools
import operator

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn

16
from utils import assert_allclose
17
18
19
20
21
22
from transformer_engine.common.recipe import Format
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.cpp_extensions import dequantize, quantize
from transformer_engine.jax.dot import fp8_dot
from transformer_engine.jax.fp8 import DType, FP8GemmPackage, FP8Helper, _format2dtypes
23
from transformer_engine.jax.fp8 import is_fp8_available
24
25
26
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import fp8_ln_mlp

Tim Moon's avatar
Tim Moon committed
27
28
29
30
31
32
33
# (m, n, k) GEMM shapes: the tests multiply an (m, k) input by a (k, n) weight.
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
# (forward, backward) FP8 dtype pairs produced for each recipe format under test.
FP8_COMPUTE_TYPE = [_format2dtypes(Format.E4M3), _format2dtypes(Format.HYBRID)]
# (n, hidden) shapes for the normalization tests.
LN_CASES = [(512, 1024)]
# High-precision dtypes the normalization tests are run with.
DTYPES = [jnp.bfloat16, jnp.float32]

# Tests decorated with skipif(not is_fp8_supported, reason=reason) are skipped
# when FP8 is reported as unavailable; `reason` is shown in the skip message.
is_fp8_supported, reason = is_fp8_available()


class TestFP8Dot:
    """Tests for quantize/dequantize and the ``fp8_dot`` custom call.

    GEMM tests compare ``fp8_dot`` (fed through an ``FP8GemmPackage``) against
    ``jnp.dot`` / ``lax.dot_general`` on the same inputs, for both the value
    and the gradients w.r.t. the inputs and weights.
    """

    @staticmethod
    def _init_fp8_meta(num_gemms=1):
        """Build fresh FP8 meta arrays for ``num_gemms`` GEMMs.

        Returns ``[fp8_max, amax, scale, scale_inv]`` in the positional order
        ``FP8GemmPackage`` takes them: amax history starts at zero and both
        scale and scale_inv start at one.
        """
        num_meta = FP8Helper.NUM_META_PER_GEMM * num_gemms
        fp8_max = FP8Helper.generate_fp8_max_array(num_meta)
        fp8_metas_amax = jnp.zeros((num_meta, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
        fp8_metas_scale = jnp.ones((num_meta, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((num_meta, 1), jnp.float32)
        return [fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv]

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        """Round-trip a small tensor through E4M3 quantize/dequantize."""
        FP8_E4M3_MAX = 448
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

        y, new_amax = quantize(x, amax, scale, scale_inv, out_dtype=DType.kFloat8E4M3)
        # amax reported by quantize is max(|x|) of the input above.
        assert_allclose(new_amax, 3.0, rtol=0, atol=0)

        no_use = jnp.zeros(1, jnp.float32)
        z = dequantize(y,
                       no_use,
                       no_use,
                       scale_inv,
                       fp8_dtype=DType.kFloat8E4M3,
                       out_dtype=DType.kFloat32)
        assert_allclose(z, x, dtype=DType.kFloat8E4M3)

    def test_compile_bf16(self):
        """AOT-compile value_and_grad of a non-FP8 fp8_dot and run it once."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (256, 512), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (512, 256), jnp.bfloat16)

        def func(x, y):
            # x = input, matrix 2d
            # y = input, matrix 2d (weight)
            fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *self._init_fp8_meta())
            return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)))

        value_n_grad_func = value_and_grad(func, (0, 1))
        value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
        value_n_grad_func_compiled(a, b)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
    def test_compile_fp8(self, compute_type):
        """AOT-compile value_and_grad of an FP8 fp8_dot and run it once."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (256, 512), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (512, 256), jnp.bfloat16)

        def func(x, y):
            fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *self._init_fp8_meta())
            return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type))

        value_n_grad_func = value_and_grad(func, (0, 1))
        value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
        value_n_grad_func_compiled(a, b)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        """Forward fp8_dot with non-FP8 dtypes matches jnp.dot."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *self._init_fp8_meta())

        fwd_dtype, bwd_dtype = _format2dtypes(None)
        primitive_out = fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype)

        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
    def test_forward_fp8_randint(self, m, n, k, compute_type):
        """Forward FP8 fp8_dot on small-integer inputs matches jnp.dot."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)

        fp8_meta = self._init_fp8_meta()

        # calculate amax
        # NOTE(review): fp8_meta is not reassigned from this call, so the
        # _update_fp8_metas_impl below sees the initial (all-zero) amax history.
        # Confirm whether this warm-up is meant to feed amax back into fp8_meta.
        fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
        primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type)

        # calculate scale by amax
        fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)

        fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
        primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type)

        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=compute_type[0])

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        """Value and input/weight grads of a non-FP8 fp8_dot match jnp.dot."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        fwd_dtype, bwd_dtype = _format2dtypes(None)

        def primitive_func(x, y):
            fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *self._init_fp8_meta())
            return jnp.mean(fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype))

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=bwd_dtype)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=bwd_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
    def test_grad_fp8_randint(self, m, n, k, compute_type):
        """Value and grads of an FP8 fp8_dot on small-integer inputs match jnp.dot."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)

        fp8_meta = self._init_fp8_meta()

        def primitive_func(x, y, metas):
            fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *metas)
            return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type))

        def ref_func(x, y):
            return jnp.sum(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        # calculate amax
        # NOTE(review): grads are taken only w.r.t. (x, y) and fp8_meta is not
        # reassigned from this call, so the update below starts from the initial
        # meta. Confirm this warm-up is intended.
        primitive_out, (primitive_a_grad,
                        primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta)

        # calculate scale by amax
        fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)
        primitive_out, (primitive_a_grad,
                        primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta)

        assert_allclose(primitive_out, ref_out, dtype=compute_type[0])
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=compute_type[1])
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=compute_type[1])

    def test_contracting_dims_bf16(self):
        """fp8_dot with explicit multi-axis contracting dims matches lax.dot_general."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (32, 8, 16, 64), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (16, 64, 128), jnp.bfloat16)

        fwd_dtype, bwd_dtype = _format2dtypes(None)

        def primitive_func(x, y):
            fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *self._init_fp8_meta())
            # Contract a's axes (2, 3) with b's axes (0, 1).
            return jnp.sum(fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype, ((2, 3), (0, 1))))

        def ref_func(x, y):
            return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ()))))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=bwd_dtype)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=bwd_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    def test_grad_fp8_mlp_randint(self, m, n, k):
        """fused fp8_ln_mlp (rmsnorm + 2 FP8 GEMMs) vs. a Python composition of fp8_dot."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        activations = ('gelu', 'linear')

        a = jax.random.uniform(subkeys[0], (m, k), jnp.bfloat16, 5, 8)
        k1 = jax.random.uniform(subkeys[1], (k, n * len(activations)), jnp.bfloat16, 5, 8)
        k2 = jax.random.uniform(subkeys[2], (n, k), jnp.bfloat16, 5, 8)
        s = jax.random.uniform(subkeys[3], (k,), jnp.bfloat16, 5, 8)

        # Two GEMMs: the first NUM_META_PER_GEMM entries serve GEMM 1, the rest GEMM 2.
        fp8_meta = self._init_fp8_meta(num_gemms=2)
        compute_type = _format2dtypes(Format.HYBRID)

        def primitive_func(x, ln_s, y, z, metas):
            # x is input tensor, matrix 2d
            # y, z are weights, matrix 2d
            # out = (x * y) * z
            fp8_gemm_pkg = FP8GemmPackage(2, x, [y, z], *metas)
            return jnp.mean(
                fp8_ln_mlp(fp8_gemm_pkg,
                           ln_s,
                           None,
                           "rmsnorm",
                           *compute_type,
                           activations=activations))

        def _convert_to_activation_function(fn_or_string):
            """Convert a string to an activation function."""
            if fn_or_string == 'linear':
                return lambda x: x
            if isinstance(fn_or_string, str):
                return getattr(nn, fn_or_string)
            if callable(fn_or_string):
                return fn_or_string
            raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")

        def fp8_ln_mlp_py(inputs: jnp.ndarray,
                          ln_scale: jnp.ndarray,
                          kernel_1: jnp.ndarray,
                          kernel_2: jnp.ndarray,
                          fp8_maxs: jnp.ndarray,
                          amax: jnp.ndarray,
                          scale: jnp.ndarray,
                          scale_inv: jnp.ndarray,
                          fwd_dtype,
                          bwd_dtype,
                          epsilon=1e-6,
                          contracting_dims=((-1,), (0,)),
                          dp_dim_index=0,
                          activations=('gelu', 'linear')) -> jnp.ndarray:
            """Reference: RMSNorm -> fp8_dot -> product of activations -> fp8_dot."""
            x = jnp.asarray(inputs, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)
            fp8_gemm_1_pkg = FP8GemmPackage(1, ln_out, [kernel_1],
                                            fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
                                            amax[:FP8Helper.NUM_META_PER_GEMM],
                                            scale[:FP8Helper.NUM_META_PER_GEMM],
                                            scale_inv[:FP8Helper.NUM_META_PER_GEMM])
            linear_1_out = fp8_dot(fp8_gemm_1_pkg,
                                   fwd_dtype,
                                   bwd_dtype,
                                   contracting_dims,
                                   dp_dim_index=dp_dim_index)
            # kernel_1 holds one column block per activation; apply each and multiply.
            x = jnp.split(linear_1_out, len(activations), axis=-1)
            acts = []
            for idx, act_fn in enumerate(activations):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            x = jnp.asarray(x, jnp.bfloat16)
            fp8_gemm_2_pkg = FP8GemmPackage(1, x, [kernel_2],
                                            fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
                                            amax[FP8Helper.NUM_META_PER_GEMM:],
                                            scale[FP8Helper.NUM_META_PER_GEMM:],
                                            scale_inv[FP8Helper.NUM_META_PER_GEMM:])
            output = fp8_dot(fp8_gemm_2_pkg,
                             fwd_dtype,
                             bwd_dtype,
                             contracting_dims,
                             dp_dim_index=dp_dim_index)
            return output

        def ref_func(x, ln_s, y, z, metas):
            return jnp.mean(
                fp8_ln_mlp_py(x, ln_s, y, z, *metas, *compute_type, activations=activations))

        value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3)))

        ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad,
                  ref_k2_grad) = value_n_grad_ref_func(a, s, k1, k2, fp8_meta)

        # calculate amax
        primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                        primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta)

        # calculate scale by amax
        fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)
        primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                        primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta)

        assert_allclose(primitive_out, ref_out, dtype=compute_type[0])
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=compute_type[1])
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=compute_type[1])
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=compute_type[1])
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=compute_type[1])


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    """Provide deterministic bf16 samples drawn uniformly between 5 and 8.

    ``shape`` comes from the test's ``@pytest.mark.parametrize('shape', ...)``.
    The draw uses the first of four keys split from ``PRNGKey(0)``.
    """
    first_subkey = jax.random.split(jax.random.PRNGKey(0), 4)[0]
    return jax.random.uniform(first_subkey, shape, jnp.bfloat16, 5, 8)


class TestGatedGeLu:
    """Compare the TE ``gated_gelu`` / ``dgated_gelu`` custom calls with plain JAX."""

    def ref_func(self, inputs):
        """Return jitted ``(mean(out), d mean / d inputs)`` of a pure-JAX gated GeLU.

        The last axis of ``inputs`` is split in half: GeLU is applied to the
        first half and the result is multiplied by the second half, then cast
        to bfloat16.
        """

        def jax_gated_gelu(x):
            gate_half, linear_half = jnp.split(x, 2, axis=-1)
            return jnp.asarray(jax.nn.gelu(gate_half) * linear_half, jnp.bfloat16)

        def mean_loss(x):
            return jnp.mean(jax_gated_gelu(x))

        return jit(value_and_grad(mean_loss))(inputs)

    def prim_func(self, inputs):
        """Return jitted ``(mean(out), grad)`` using the TE custom calls via custom_vjp."""

        def primitive_fwd(x):
            # Save the raw input as the residual for the backward pass.
            return gated_gelu(x), x

        def primitive_bwd(saved_x, g):
            return (dgated_gelu(g, saved_x),)

        @jax.custom_vjp
        def primitive(x):
            return primitive_fwd(x)[0]

        primitive.defvjp(primitive_fwd, primitive_bwd)

        def mean_loss(x):
            return jnp.mean(primitive(x))

        return jit(value_and_grad(mean_loss))(inputs)

    @pytest.mark.parametrize('shape', [(32, 64), (64, 256)])
    def test_gated_gelu(self, random_inputs):
        """Value and gradient of the custom call agree with the JAX reference."""
        prim_out, prim_grad = self.prim_func(random_inputs)
        ref_out, ref_grad = self.ref_func(random_inputs)

        assert_allclose(prim_out, ref_out, rtol=1e-2)
        assert_allclose(prim_grad, ref_grad, rtol=1e-1, atol=1e-3)


class TestGatedGeLuFP8(TestGatedGeLu):
    """Gated GeLU test using the FP8 custom calls.

    Reuses the inherited ``ref_func`` as the high-precision reference. It
    overrides ``prim_func`` to run ``gated_gelu_fp8`` in the forward pass and
    ``dgated_gelu_cast_transpose`` in the backward pass, dequantizing E5M2
    results back to high precision for comparison.
    """

    def prim_func(self, inputs):
        """Return ``(mean(out), (dgelu, dgelu_trans, amax))`` from the FP8 path.

        ``self.amax``, ``self.scale`` and ``self.scale_inv`` must be set before
        calling. The two extra inputs are dummy ``(1,)`` arrays whose
        "gradients" are used to carry ``dgelu_trans`` and ``amax`` out of
        ``value_and_grad``.
        """
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        no_use = jnp.zeros(1, jnp.float32)

        @jax.custom_vjp
        def primitive(x, y, z):
            # Return only the primal output. primitive_fwd also returns the
            # residual, so returning its whole result would give the primal
            # function a different output structure than the fwd rule's primal.
            out, _ = primitive_fwd(x, y, z)
            return out

        def primitive_fwd(x, y, z):    # pylint: disable=unused-argument
            out, _ = gated_gelu_fp8(x, amax, scale, scale_inv, DType.kFloat8E5M2)
            out = dequantize(out, no_use, no_use, scale_inv, DType.kFloat8E5M2, DType.kBFloat16)
            ctx = x
            return out, ctx

        def primitive_bwd(ctx, g):
            x = ctx
            dgelu, dgelu_trans, amax_out = dgated_gelu_cast_transpose(g, x, amax, scale, scale_inv,
                                                                      DType.kFloat8E5M2)
            dgelu = dequantize(dgelu, no_use, no_use, scale_inv, DType.kFloat8E5M2, DType.kFloat32)
            dgelu_trans = dequantize(dgelu_trans, no_use, no_use, scale_inv, DType.kFloat8E5M2,
                                     DType.kFloat32)
            # The cotangents for y and z are deliberately repurposed to return
            # dgelu_trans and amax_out; they are not true gradients of y and z.
            # NOTE(review): their shapes differ from y/z (both (1,)); newer JAX
            # versions may reject this in custom_vjp. Confirm with the pinned JAX.
            return dgelu, dgelu_trans, amax_out

        primitive.defvjp(primitive_fwd, primitive_bwd)
        func = jit(value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2)))

        return func(inputs, no_use, no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 64), (64, 256)])
    def test_gated_gelu(self, random_inputs):
        """FP8 value, gradient, transposed gradient and amax match the JAX reference."""
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)

        x = random_inputs
        prim_out, (prim_grad, prim_grad_trans, amax) = self.prim_func(x)
        ref_out, ref_grad = self.ref_func(x)

        assert_allclose(prim_out, ref_out, rtol=1e-2)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        assert_allclose(prim_grad, ref_grad, rtol=1e-1, atol=1e-3)
        assert_allclose(prim_grad_trans, jnp.transpose(ref_grad), rtol=1e-1, atol=1e-3)


class TestRMSNorm:
    """Compare TE ``layernorm(..., "rmsnorm")`` against a pure-JAX RMSNorm."""

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    def test_forward_backward(self, n, hidden, dtype):
        """Mean-loss value and grads w.r.t. input and gamma match the reference."""
        x_key, gamma_key = jax.random.split(jax.random.PRNGKey(0), 2)

        x = jax.random.uniform(x_key, (n, hidden), dtype, -2, 1)
        # Sample gamma in float32 first, then cast to the test dtype.
        gamma = jnp.asarray(jax.random.uniform(gamma_key, (hidden,), jnp.float32, -2, 1), dtype)
        eps = 1e-6

        def reference_rmsnorm(inp, gamma_):
            # Normalize in float32, cast back to the test dtype, then scale.
            inp_f32 = jnp.asarray(inp, jnp.float32)
            mean_sq = jnp.mean(lax.square(inp_f32), axis=-1, keepdims=True)
            normed = jnp.asarray(inp_f32 * lax.rsqrt(mean_sq + eps), dtype)
            return normed * gamma_

        def primitive_loss(inp, gamma_):
            return jnp.mean(layernorm(inp, gamma_, None, "rmsnorm"))

        def reference_loss(inp, gamma_):
            return jnp.mean(reference_rmsnorm(inp, gamma_))

        prim_fn = jit(value_and_grad(primitive_loss, (0, 1)))
        ref_fn = jit(value_and_grad(reference_loss, (0, 1)))

        prim_out, (prim_dx, prim_dgamma) = prim_fn(x, gamma)
        ref_out, (ref_dx, ref_dgamma) = ref_fn(x, gamma)

        is_fp32 = dtype == jnp.float32
        if not is_fp32:
            assert_allclose(prim_out, ref_out, rtol=1e-3)
            assert_allclose(prim_dx, ref_dx, rtol=1e-4, atol=5e-8)
            assert_allclose(prim_dgamma, ref_dgamma, rtol=1e-4, atol=5e-8)
            return

        assert_allclose(prim_out, ref_out, rtol=1e-7)
        assert_allclose(prim_dx, ref_dx, rtol=1e-7)
        assert_allclose(prim_dgamma, ref_dgamma, rtol=1e-7)


class TestLayerNorm:
    """Compare TE ``layernorm(..., "layernorm")`` against a pure-JAX LayerNorm."""

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype):
        """Loss value and grads w.r.t. x, gamma and beta match the reference.

        With ``zero_centered_gamma`` the reference applies ``scale + 1``, so
        scale is sampled around 0 instead of around 1.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
        scale_range = (-1, 1) if zero_centered_gamma else (0, 2)
        scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range)
        scale = jnp.asarray(scale, dtype)
        bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
        bias = jnp.asarray(bias, dtype)
        epsilon = 1e-6

        def reference_layernorm(x, scale, bias, zero_centered_gamma, eps):
            # Normalize in float32; the result is cast back to x's dtype.
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
            # Align TE implementation: with zero_centered_gamma the effective
            # scale is (scale + 1).
            if zero_centered_gamma:
                return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype)
            return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        jitted_primitive = jit(
            value_and_grad(
                lambda x, scale, bias: compute_loss(
                    layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)),
                (0, 1, 2)))

        jitted_reference = jit(
            value_and_grad(
                lambda x, scale, bias: compute_loss(
                    reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2)))

        primitive_out, (primitive_dx, primitive_dgamma,
                        primitive_dbeta) = jitted_primitive(x, scale, bias)
        reference_out, (reference_dx, reference_dgamma,
                        reference_dbeta) = jitted_reference(x, scale, bias)

        if dtype == jnp.float32:
            assert_allclose(primitive_out, reference_out, rtol=1e-7)
            assert_allclose(primitive_dx, reference_dx, rtol=1e-7)
            assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-7)
            assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-7)
        else:
            assert_allclose(primitive_out, reference_out, rtol=1e-7)
            assert_allclose(primitive_dx, reference_dx, rtol=1e-5, atol=1e-6)
            assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-5, atol=3e-5)
            assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-5, atol=3e-5)