# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from contextlib import nullcontext
from typing import Callable, List, Sequence, Union
import os

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose, assert_tree_like_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax.cpp_extensions.transpose import (
    _jax_transpose,
    _jax_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex


GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()


class TestFP8Dot:
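    """Test quantize/dequantize, type_safe_dot_general, and fused_layernorm_fp8_mlp in bf16 and FP8."""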

    @staticmethod
    def _generate_fp8_meta():
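        """Return default FP8 metadata: dtypes, zeroed amax histories, and unit scales for one GEMM."""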
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
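        """Quantize to FP8 E4M3 and dequantize back; the round trip should match the input within FP8 tolerance."""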
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
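        """Compare the bf16 forward pass of type_safe_dot_general against jnp.dot."""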
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
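        """Compare the FP8 forward pass of type_safe_dot_general against jnp.dot on small random integers."""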
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        dtype = jnp.bfloat16

        # TODO(rewang): add float random test
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(dtype)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(dtype)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
        fp8_meta_pkg = FP8MetaPackage(
            amax_list[0],
            scale_list[0],
            amax_list[1],
            scale_list[1],
            amax_list[2],
            scale_list[2],
        )
        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
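        """Compare the value and gradients of a bf16 type_safe_dot_general against a jnp.dot reference."""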
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
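        """Compare the value and gradients of an FP8 type_safe_dot_general against a jnp.dot reference."""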
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()

        def primitive_func(x, y, amax_list, scale_list):
            fp8_meta_pkg = FP8MetaPackage(
                amax_list[0],
                scale_list[0],
                amax_list[1],
                scale_list[1],
                amax_list[2],
                scale_list[2],
            )
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

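        # Run a few iterations, feeding the updated amax histories and scales back
        # into the next call, so the FP8 metadata settles before the comparison.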
        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
                value_n_grad_primitive_func(a, b, amax_list, scale_list)
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize(
        "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
    )
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_grad_fused_layernorm_fp8_mlp(
        self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
    ):
        """N/a"""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = None
            b2 = None

        def primitive_func(
            x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
        ):
            # x is the input tensor (2D); y, z are the layer weights (2D); w, v are
            # the optional biases. Roughly: out = act(rmsnorm(x) @ y + w) @ z + v
            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            return jnp.mean(
                fused_layernorm_fp8_mlp(
                    x,
                    ln_s,
                    None,
                    [y, z],
                    [w, v],
                    [fp8_meta_pkg_1, fp8_meta_pkg_2],
                    "rmsnorm",
                    activation_type=activation_type,
                    use_bias=use_bias,
                )
            )

        def layernorm_fp8_mlp_ref(
            x: jnp.ndarray,
            ln_scale: jnp.ndarray,
            kernel_1: jnp.ndarray,
            kernel_2: jnp.ndarray,
            bias_1: jnp.ndarray,
            bias_2: jnp.ndarray,
            amax_list_1: List[jnp.ndarray],
            amax_list_2: List[jnp.ndarray],
            scale_list_1: List[jnp.ndarray],
            scale_list_2: List[jnp.ndarray],
        ) -> jnp.ndarray:

            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_meta_pkg_1, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)

            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            output = type_safe_dot_general(x, kernel_2, fp8_meta_pkg_2, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
            return jnp.mean(
                layernorm_fp8_mlp_ref(
                    x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
                )
            )

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
        )
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
        _, amax_list_2, scale_list_2 = TestFP8Dot._generate_fp8_meta()

        ref_amax_list_1 = amax_list_1
        ref_scale_list_1 = scale_list_1
        ref_amax_list_2 = amax_list_2
        ref_scale_list_2 = scale_list_2

        primitive_amax_list_1 = amax_list_1
        primitive_scale_list_1 = scale_list_1
        primitive_amax_list_2 = amax_list_2
        primitive_scale_list_2 = scale_list_2

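        # Run the reference and primitive paths for a few iterations each, feeding the
        # updated amax histories and scales back in before comparing the final results.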
        for _ in range(3):
            ref_out, (
                ref_a_grad,
                ref_s_grad,
                ref_k1_grad,
                ref_k2_grad,
                ref_b1_grad,
                ref_b2_grad,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            ) = value_n_grad_ref_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            )

        for _ in range(3):
            primitive_out, (
                primitive_a_grad,
                primitive_s_grad,
                primitive_k1_grad,
                primitive_k2_grad,
                primitive_b1_grad,
                primitive_b2_grad,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            ) = value_n_grad_primitive_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(
            jnp.asarray(primitive_a_grad, np.float32),
            jnp.asarray(ref_a_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k1_grad, np.float32),
            jnp.asarray(ref_k1_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_s_grad, np.float32),
            jnp.asarray(ref_s_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k2_grad, np.float32),
            jnp.asarray(ref_k2_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        if use_bias:
            assert_allclose(
                jnp.asarray(primitive_b2_grad, np.float32),
                jnp.asarray(ref_b2_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )
            assert_allclose(
                jnp.asarray(primitive_b1_grad, np.float32),
                jnp.asarray(ref_b1_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
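    """Return a bf16 tensor of the given shape with values drawn uniformly from [5, 8)."""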
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


class TestActivationLu:
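    """Test activation_lu (fused, optionally gated activations) against the pure-JAX reference _jax_act_lu."""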

    def ref_func(self, x, activation_type):

        def ref_act_lu(inputs):
            x = _jax_act_lu(inputs, activation_type)
            return jnp.mean(x)

        ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
        return ref_act_func(x)

    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)))

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestActivationLuFP8(TestActivationLu):
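    """Test the FP8 activation kernels (act_lu_fp8 and the fused backward cast/transpose ops) against the reference."""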

    def prim_func(self, x):
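        """Wrap tex.act_lu_fp8 (forward) and the fused dact/cast/transpose kernels
        (backward) in a custom VJP so value_and_grad exercises both custom calls."""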
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        activation_type = self.activation_type

        @jax.custom_vjp
        def _prim_func(x, _x_t, _dbias, _amax):
            # The fwd rule returns (output, residuals); the primal function must
            # return only the output.
            output, _ = _prim_func_fwd(x, _x_t, _dbias, _amax)
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            activation_lu_out, _ = tex.act_lu_fp8(
                x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
            )
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = x
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1:  # gated, no bias
                dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
                    g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
                )
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else:  # not gated, with bias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
                    tex.dact_lu_dbias_cast_transpose(
                        g,
                        x,
                        amax,
                        scale,
                        scale_inv,
                        FP8Helper.BWD_DTYPE,
                        -1,
                        -2,
                        self.activation_type,
                    )
                )
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
            return ctx

        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)

        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_axes], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(
            lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
        )
        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type

        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        axes = jnp.arange(x.ndim)
        self.transpose_axes = tuple([*axes[-2:]] + [*axes[:-2]])

        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if "linear" not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(
            prim_grad_trans,
            jnp.transpose(ref_grad, self.transpose_axes),
            dtype=FP8Helper.BWD_DTYPE,
        )


class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    @staticmethod
    def _generate_fp8_meta():
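        """Return default FP8 metadata: dtypes, zeroed amax histories, and unit scales for one GEMM."""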
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX native layernorm implementations
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            mean = 0.0
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            scale += 1.0
        if bias is None:
            bias = 0.0
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

    @pytest.mark.parametrize("n, hidden", LN_CASES)
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_layernorm_forward_backward(
        self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)

            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == "layernorm":
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None

            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)

            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
                x, gamma, beta
            )
            reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
                x, gamma, beta
            )

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [True, False])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)

            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == "layernorm":
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None

            _, amax_list_1, scale_list_1 = TestNorm._generate_fp8_meta()

            def primitive_func(x, y, gamma, beta, amax_list_1, scale_list_1):
                fp8_meta_pkg = FP8MetaPackage(
                    amax_list_1[0],
                    scale_list_1[0],
                    amax_list_1[1],
                    scale_list_1[1],
                    amax_list_1[2],
                    scale_list_1[2],
                )
                primitive_out = layernorm_fp8_dot(
                    x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
                )
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))

            value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
                value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
            )

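            # Run a few iterations so the updated FP8 metadata feeds back in before
            # comparing against the reference.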
            for _ in range(3):
                primitive_out, (
                    primitive_a_grad,
                    primitive_b_grad,
                    primitive_gamma_grad,
                    primitive_beta_grad,
                    amax_list_1,
                    scale_list_1,
                ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)


@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "input_shape, transpose_axis",
    [
        pytest.param((16, 16), 1, id="(16, 16)-1"),
        pytest.param((256, 128), 1, id="(256, 128)-1"),
        pytest.param((128, 512), 1, id="(128, 512)-1"),
        pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"),
        pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"),
        pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"),
    ],
)
class TestTranspose:
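    """Compare the pure-JAX transpose helpers against the TE custom calls, with and without the FFI path (NVTE_JAX_WITH_FFI)."""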
    def test_transpose(self, in_dtype, input_shape, transpose_axis):
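        """Transpose results must match across the JAX reference, the non-FFI custom call, and the FFI custom call."""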
        key = jax.random.PRNGKey(0)
        input_tensor = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis)
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
        assert_allclose(jax_output, noffi_output)
        assert_allclose(noffi_output, ffi_output)

    @pytest.mark.parametrize(
        "out_dtype",
        [
            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
        ],
    )
    def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
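        """Cast-transpose results must match across the JAX reference and both custom-call paths."""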
        amax = jnp.zeros(1, jnp.float32)
        scale = jnp.ones(1, jnp.float32)
        scale_inv = jnp.ones(1, jnp.float32)
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_cast_transpose(
            input, scale, amax, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        assert_tree_like_allclose(jax_output, ffi_output)
        assert_tree_like_allclose(noffi_output, ffi_output)


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
    "input_shape",
    [
        pytest.param((256, 128), id="(256, 128)"),
        pytest.param((128, 512, 8), id="(128, 512, 8)"),
    ],
)
@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "out_dtype",
    [
        pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
        pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
    ],
)
def test_quantize(input_shape, in_dtype, out_dtype):
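    """FP8 cast results must match across the JAX reference (_jax_cast_fp8), the non-FFI, and the FFI custom call."""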
    amax = jnp.zeros(1, jnp.float32)
    scale = jnp.ones(1, jnp.float32)
    scale_inv = jnp.ones(1, jnp.float32)
    key = jax.random.PRNGKey(0)
    input = jax.random.uniform(key, input_shape, in_dtype)
    jax_output = _jax_cast_fp8(input, scale, amax, out_dtype)
    os.environ["NVTE_JAX_WITH_FFI"] = "0"
    noffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
    os.environ["NVTE_JAX_WITH_FFI"] = "1"
    ffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
    assert_tree_like_allclose(jax_output, ffi_output)
    assert_tree_like_allclose(noffi_output, ffi_output)