# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
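"""Tests for Transformer Engine's JAX custom-call compute primitives:
FP8 quantize/dequantize and GEMM, fused LayerNorm-MLP, (gated) GeLU,
and LayerNorm/RMSNorm forward and backward passes."""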

import functools
import operator

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import dgelu, dgelu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import gelu, gelu_fp8
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp

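# (m, n, k) problem sizes used to parametrize the GEMM correctness tests.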
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


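# Compare TE's type-safe dot and fused LayerNorm-MLP custom calls against plain
# jax.numpy references, for forward values and gradients, in BF16 and FP8.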
class TestFP8Dot:

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)

        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)

        # Recompute the FP8 scale from the recorded amax, then rerun the dot
        fp8_metas_scale, fp8_metas_scale_inv = FP8Helper.update_fp8_scale(
            fp8_max, fp8_metas_amax, fp8_metas_scale)
        fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)

        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                   jnp.float32)
        fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
        fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)

        def primitive_func(x, y, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
            fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, fp8_max, fp8_metas_amax,
                            fp8_metas_scale, fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, b, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    def test_grad_ln_geglu_fp8_mlp(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        activations = ('gelu', 'linear')

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
        s = jax.random.normal(subkeys[3], (k,), jnp.bfloat16)

        init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
        init_fp8_metas_amax = jnp.zeros(
            (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
        init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
        init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)

        def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
                           fp8_metas_scale_inv):
            # x: 2D input tensor; ln_s: RMSNorm scale; y, z: 2D weight matrices
            # out = gated_gelu(rmsnorm(x) @ y) @ z
            fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))

        def _convert_to_activation_function(fn_or_string):
            """Convert a string to an activation function."""
            if fn_or_string == 'linear':
                return lambda x: x
            if isinstance(fn_or_string, str):
                return getattr(nn, fn_or_string)
            if callable(fn_or_string):
                return fn_or_string
            raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")

        def ln_geglu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                                 kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
                                 scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray:

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
                                            amax[:FP8Helper.NUM_META_PER_GEMM],
                                            scale[:FP8Helper.NUM_META_PER_GEMM],
                                            scale_inv[:FP8Helper.NUM_META_PER_GEMM])
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))

            x = jnp.split(linear_1_out, len(activations), axis=-2)
            acts = []
            for idx, act_fn in enumerate(activations):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                acts.append(x_i)
            x = functools.reduce(operator.mul, acts)
            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
                                            amax[FP8Helper.NUM_META_PER_GEMM:],
                                            scale[FP8Helper.NUM_META_PER_GEMM:],
                                            scale_inv[FP8Helper.NUM_META_PER_GEMM:])
            output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
            return output

        def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv):
            return jnp.mean(
                ln_geglu_fp8_mlp_ref(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                     fp8_metas_scale_inv))

        value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7)))

        ref_fp8_max = init_fp8_max
        ref_fp8_metas_amax = init_fp8_metas_amax
        ref_fp8_metas_scale = init_fp8_metas_scale
        ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        pri_fp8_max = init_fp8_max
        pri_fp8_metas_amax = init_fp8_metas_amax
        pri_fp8_metas_scale = init_fp8_metas_scale
        pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        for _ in range(3):
            ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_fp8_max,
                      ref_fp8_metas_amax, ref_fp8_metas_scale,
                      ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
                          a, s, k1, k2, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
                          ref_fp8_metas_scale_inv)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                            primitive_k2_grad, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
                            pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, s, k1, k2, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale,
                                pri_fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)
        activations = ('gelu',)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
        b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
        b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)

        init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
        init_fp8_metas_amax = jnp.zeros(
            (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
        init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
        init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)

        def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                           fp8_metas_scale_inv):
            # x: 2D input tensor; ln_s: RMSNorm scale; y, z: 2D weight matrices; w, v: biases
            # out = gelu(rmsnorm(x) @ y + w) @ z + v
            fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                          fp8_metas_scale_inv)
            return jnp.mean(
                layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))

        def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                                kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
                                fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                                scale_inv: jnp.ndarray) -> jnp.ndarray:

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
                                            amax[:FP8Helper.NUM_META_PER_GEMM],
                                            scale[:FP8Helper.NUM_META_PER_GEMM],
                                            scale_inv[:FP8Helper.NUM_META_PER_GEMM])
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))

            bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
            linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = jax.nn.gelu(linear_1_out)
            x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

            fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
                                            amax[FP8Helper.NUM_META_PER_GEMM:],
                                            scale[FP8Helper.NUM_META_PER_GEMM:],
                                            scale_inv[FP8Helper.NUM_META_PER_GEMM:])
            output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))

            bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
            output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                     fp8_metas_scale_inv):
            return jnp.mean(
                ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                    fp8_metas_scale_inv))

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        ref_fp8_max = init_fp8_max
        ref_fp8_metas_amax = init_fp8_metas_amax
        ref_fp8_metas_scale = init_fp8_metas_scale
        ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        pri_fp8_max = init_fp8_max
        pri_fp8_metas_amax = init_fp8_metas_amax
        pri_fp8_metas_scale = init_fp8_metas_scale
        pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv

        for _ in range(3):
            ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
                      ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
                      ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
                          a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
                          ref_fp8_metas_scale, ref_fp8_metas_scale_inv)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                            primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
                            pri_fp8_metas_amax, pri_fp8_metas_scale,
                            pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
                                a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
                                pri_fp8_metas_scale, pri_fp8_metas_scale_inv)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                        jnp.asarray(ref_a_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                        jnp.asarray(ref_k1_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                        jnp.asarray(ref_k2_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                        jnp.asarray(ref_s_grad, np.float32),
                        dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
                        jnp.asarray(ref_b1_grad, np.float32),
                        dtype=jnp.bfloat16)
        assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
                        jnp.asarray(ref_b2_grad, np.float32),
                        dtype=jnp.bfloat16)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


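# GeLU: compare the TE custom-call forward/backward against jax.nn.gelu.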
class TestGeLu:

    def ref_func(self, inputs):

        func = jit(value_and_grad(lambda x: jnp.mean(jax.nn.gelu(x))))
        return func(inputs)

    def prim_func(self, inputs):

        @jax.custom_vjp
        def primitive(x):
            out, _ = primitive_fwd(x)
            return out

        def primitive_fwd(x):
            out = gelu(x)
            ctx = x
            return out, ctx

        def primitive_bwd(ctx, g):
            x = ctx
            out = dgelu(g, x)
            return (out,)

        primitive.defvjp(primitive_fwd, primitive_bwd)
        func = value_and_grad(lambda x: jnp.mean(primitive(x)))
        return func(inputs)

    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    def test_gelu(self, random_inputs):
        x = random_inputs
        prim_out, prim_grad = self.prim_func(x)
        ref_out, ref_grad = self.ref_func(x)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


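# FP8 GeLU: outputs are dequantized before comparison; the dbias, amax, and
# transposed-gradient side outputs of the fused backward kernel are checked too.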
class TestGeLuFP8(TestGeLu):

    def prim_func(self, inputs):
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        no_use = jnp.zeros(1, jnp.float32)

        @jax.custom_vjp
        def primitive(x, y, z, w):
            out, _ = primitive_fwd(x, y, z, w)
            return out

        def primitive_fwd(x, y, z, w):
            out, _ = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
            out = dequantize(out, x.dtype, scale_inv)
            ctx = x
            return out, ctx

        def primitive_bwd(ctx, g):
            x = ctx
            dgelu, dgelu_trans, dbias, amax_out = dgelu_dbias_cast_transpose(
                g, x, amax, scale, scale_inv, jnp.float8_e5m2, -1)
            dgelu = dequantize(dgelu, x.dtype, scale_inv)
            dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
            return dgelu, dgelu_trans, dbias, amax_out

        primitive.defvjp(primitive_fwd, primitive_bwd)
        func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3))

        return func(inputs, no_use, no_use, no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    def test_gelu(self, random_inputs):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)

        x = random_inputs
        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, ref_grad = self.ref_func(x)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, (2, 0, 1)),
                        dtype=FP8Helper.BWD_DTYPE)


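# Gated GeLU (GeGLU): the size-2 second-to-last axis holds the GeLU and linear branches.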
class TestGatedGeLu:

    def ref_func(self, inputs):

        def jax_gated_gelu(x):
            x = jnp.split(x, 2, axis=-2)
            acts = [jax.nn.gelu(x[0]), x[1]]
            x = functools.reduce(operator.mul, acts)
            x = jnp.asarray(jnp.squeeze(x, -2), jnp.bfloat16)
            return x

        func = jit(value_and_grad(lambda x: jnp.mean(jax_gated_gelu(x))))
        return func(inputs)

    def prim_func(self, inputs):

        @jax.custom_vjp
        def primitive(x):
            out, _ = primitive_fwd(x)
            return out

        def primitive_fwd(x):
            out = gated_gelu(x)
            ctx = x
            return out, ctx

        def primitive_bwd(ctx, g):
            x = ctx
            out = dgated_gelu(g, x)
            return (out,)

        primitive.defvjp(primitive_fwd, primitive_bwd)
        func = value_and_grad(lambda x: jnp.mean(primitive(x)))
        return func(inputs)

    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    def test_gated_gelu(self, random_inputs):
        x = random_inputs
        prim_out, prim_grad = self.prim_func(x)
        ref_out, ref_grad = self.ref_func(x)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


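# FP8 variant of the gated GeLU test; also checks the transposed gradient and recorded amax.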
class TestGatedGeLuFP8(TestGatedGeLu):

    def prim_func(self, inputs):
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        no_use = jnp.zeros(1, jnp.float32)

        @jax.custom_vjp
        def primitive(x, y, z):
            out, _ = primitive_fwd(x, y, z)
            return out

        def primitive_fwd(x, y, z):
            out, _ = gated_gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
            out = dequantize(out, x.dtype, scale_inv)
            ctx = x
            return out, ctx

        def primitive_bwd(ctx, g):
            x = ctx
            dgelu, dgelu_trans, amax_out = dgated_gelu_cast_transpose(g, x, amax, scale, scale_inv,
                                                                      jnp.float8_e5m2, -1)
            dgelu = dequantize(dgelu, x.dtype, scale_inv)
            dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
            return dgelu, dgelu_trans, amax_out

        primitive.defvjp(primitive_fwd, primitive_bwd)
        func = value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2))

        return func(inputs, no_use, no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    def test_gated_gelu(self, random_inputs):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)

        x = random_inputs
        prim_out, (prim_grad, prim_grad_trans, amax) = self.prim_func(x)
        ref_out, ref_grad = self.ref_func(x)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, (1, 2, 0)),
                        dtype=FP8Helper.BWD_DTYPE)


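# RMSNorm custom call vs. a float32 reference implementation.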
class TestRMSNorm:

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    def test_forward_backward(self, n, hidden, dtype):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1)
        scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1)
        scale = jnp.asarray(scale, dtype)
        epsilon = 1e-6

        def reference_rmsnorm(x, scale):
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * lax.rsqrt(mean2 + epsilon), dtype)
            return y * scale

        jitted_primitive = jit(
            value_and_grad(lambda x, scale: jnp.mean(layernorm(x, scale, None, "rmsnorm")), (0, 1)))

        jitted_reference = jit(
            value_and_grad(lambda x, scale: jnp.mean(reference_rmsnorm(x, scale)), (0, 1)))

        primitive_out, (primitive_dx, primitive_dgamma) = jitted_primitive(x, scale)
        reference_out, (reference_dx, reference_dgamma) = jitted_reference(x, scale)

        assert_allclose(primitive_out, reference_out, dtype=dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)


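# LayerNorm custom call vs. a float32 reference, with and without zero-centered gamma.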
class TestLayerNorm:

    @pytest.mark.parametrize('n, hidden', LN_CASES)
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
        scale_range = (-1, 1) if zero_centered_gamma else (0, 2)
        scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range)
        scale = jnp.asarray(scale, dtype)
        bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
        bias = jnp.asarray(bias, dtype)
        epsilon = 1e-6

        def reference_layernorm(x, scale, bias, zero_centered_gamma, eps):
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
            # Align with the TE implementation of zero-centered gamma
            if zero_centered_gamma:
                return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype)
            return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        jitted_primitive = jit(
            value_and_grad(
                lambda x, scale, bias: compute_loss(
                    layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)),
                (0, 1, 2)))

        jitted_reference = jit(
            value_and_grad(
                lambda x, scale, bias: compute_loss(
                    reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2)))

        primitive_out, (primitive_dx, primitive_dgamma,
                        primitive_dbeta) = jitted_primitive(x, scale, bias)
        reference_out, (reference_dx, reference_dgamma,
                        reference_dbeta) = jitted_reference(x, scale, bias)

        assert_allclose(primitive_out, reference_out, dtype=dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
        assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)