# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
from contextlib import nullcontext
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax import linen as nn
from jax import jit, value_and_grad

from utils import assert_allclose, assert_tree_like_allclose

from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax.cpp_extensions.transpose import (
    _jax_transpose,
    _jax_cast_transpose,
    _jax_dbias_cast_transpose,
)
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
30

Tim Moon's avatar
Tim Moon committed
31
32
33
34
35
36
37
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
38
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
39
40
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
41
is_fp8_supported, reason = is_fp8_available()
42

43

44
45
class TestFP8Dot:
    """Tests for bf16 and FP8 dot/GEMM paths in transformer_engine.jax."""

    @staticmethod
    def _generate_fp8_meta():
        """Build a default FP8 meta triple: dtypes, zeroed amax histories, unit scales.

        Returns three parallel 3-element lists; the entries feed
        ``FP8MetaPackage(amax_0, scale_0, amax_1, scale_1, amax_2, scale_2)``.
        """
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        """Quantize then dequantize a small tensor; round-trip must stay close."""
        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
        # Scale so the largest |x| maps to the e4m3 max representable value.
        amax = jnp.max(jnp.abs(x)).reshape(1)
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        """type_safe_dot_general without FP8 meta must match jnp.dot in bf16."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        primitive_out = type_safe_dot_general(a, b)
        ref_out = jnp.dot(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        """FP8 dot forward on small random integers must match jnp.dot."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        dtype = jnp.bfloat16

        # TODO(rewang): add float random test
        # Small integers are exactly representable, so the comparison is tight.
        min_val, max_val = -8, 8
        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(dtype)
        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(dtype)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
        fp8_meta_pkg = FP8MetaPackage(
            amax_list[0],
            scale_list[0],
            amax_list[1],
            scale_list[1],
            amax_list[2],
            scale_list[2],
        )
        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)

        ref_out = jnp.dot(a, b)

        ref_out = ref_out.astype(jnp.float32)
        primitive_out = primitive_out.astype(jnp.float32)

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        """Gradients of mean(dot(a, b)) through the primitive match jnp.dot."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)

        def primitive_func(x, y):
            primitive_out = type_safe_dot_general(x, y)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        """FP8 dot gradients match the bf16 reference after amax warm-up.

        The primitive is differentiated w.r.t. the amax/scale lists as well so
        the updated FP8 meta can be threaded through repeated calls.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()

        def primitive_func(x, y, amax_list, scale_list):
            fp8_meta_pkg = FP8MetaPackage(
                amax_list[0],
                scale_list[0],
                amax_list[1],
                scale_list[1],
                amax_list[2],
                scale_list[2],
            )
            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
            return jnp.mean(primitive_out)

        def ref_func(x, y):
            return jnp.mean(jnp.dot(x, y))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        # Run a few iterations so the amax history/scales settle before comparing.
        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
                value_n_grad_primitive_func(a, b, amax_list, scale_list)
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize(
        "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
    )
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_grad_fused_layernorm_fp8_mlp(
        self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
    ):
        """Value and gradients of fused_layernorm_fp8_mlp vs an unfused reference.

        The reference path composes rmsnorm, type_safe_dot_general, the pure-JAX
        activation, and optional bias adds; both paths thread their own FP8 meta
        through three warm-up iterations before the comparison.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        # Kernel 1 carries one slice per activation component (gated acts use 2).
        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            b1 = None
            b2 = None

        def primitive_func(
            x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
        ):
            # x is input tensor, matrix 2d
            # y, z are weights, matrix 2d
            # out = ((x * y) + w) * z + v
            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            return jnp.mean(
                fused_layernorm_fp8_mlp(
                    x,
                    ln_s,
                    None,
                    [y, z],
                    [w, v],
                    [fp8_meta_pkg_1, fp8_meta_pkg_2],
                    "rmsnorm",
                    activation_type=activation_type,
                    use_bias=use_bias,
                )
            )

        def layernorm_fp8_mlp_ref(
            x: jnp.ndarray,
            ln_scale: jnp.ndarray,
            kernel_1: jnp.ndarray,
            kernel_2: jnp.ndarray,
            bias_1: jnp.ndarray,
            bias_2: jnp.ndarray,
            amax_list_1: List[jnp.ndarray],
            amax_list_2: List[jnp.ndarray],
            scale_list_1: List[jnp.ndarray],
            scale_list_2: List[jnp.ndarray],
        ) -> jnp.ndarray:
            # rmsnorm in fp32 for accuracy, then back to bf16.
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
            ln_out = y * ln_scale
            ln_out = jnp.asarray(ln_out, jnp.bfloat16)

            fp8_meta_pkg_1 = FP8MetaPackage(
                amax_list_1[0],
                scale_list_1[0],
                amax_list_1[1],
                scale_list_1[1],
                amax_list_1[2],
                scale_list_1[2],
            )
            linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_meta_pkg_1, ((1,), (0,)))

            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)

            fp8_meta_pkg_2 = FP8MetaPackage(
                amax_list_2[0],
                scale_list_2[0],
                amax_list_2[1],
                scale_list_2[1],
                amax_list_2[2],
                scale_list_2[2],
            )
            output = type_safe_dot_general(x, kernel_2, fp8_meta_pkg_2, ((1,), (0,)))

            if use_bias:
                bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
                output += jnp.reshape(bias_2, bias_2_shape)

            return output

        def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
            return jnp.mean(
                layernorm_fp8_mlp_ref(
                    x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
                )
            )

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
        )
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
        _, amax_list_2, scale_list_2 = TestFP8Dot._generate_fp8_meta()

        ref_amax_list_1 = amax_list_1
        ref_scale_list_1 = scale_list_1
        ref_amax_list_2 = amax_list_2
        ref_scale_list_2 = scale_list_2

        primitive_amax_list_1 = amax_list_1
        primitive_scale_list_1 = scale_list_1
        primitive_amax_list_2 = amax_list_2
        primitive_scale_list_2 = scale_list_2

        # NOTE(review): the next line is a bare tuple expression with no effect —
        # consider deleting it.
        primitive_amax_list_1, primitive_scale_list_1, primitive_amax_list_2, primitive_scale_list_2

        # Convert str to index as str is not a valid type for JAX JIT
        # NOTE(review): the comment above looks stale — no such conversion
        # happens here; verify against history before relying on it.
        for _ in range(3):
            ref_out, (
                ref_a_grad,
                ref_s_grad,
                ref_k1_grad,
                ref_k2_grad,
                ref_b1_grad,
                ref_b2_grad,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            ) = value_n_grad_ref_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            )

        for _ in range(3):
            primitive_out, (
                primitive_a_grad,
                primitive_s_grad,
                primitive_k1_grad,
                primitive_k2_grad,
                primitive_b1_grad,
                primitive_b2_grad,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            ) = value_n_grad_primitive_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(
            jnp.asarray(primitive_a_grad, np.float32),
            jnp.asarray(ref_a_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k1_grad, np.float32),
            jnp.asarray(ref_k1_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_s_grad, np.float32),
            jnp.asarray(ref_s_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k2_grad, np.float32),
            jnp.asarray(ref_k2_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        if use_bias:
            assert_allclose(
                jnp.asarray(primitive_b2_grad, np.float32),
                jnp.asarray(ref_b2_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )
            assert_allclose(
                jnp.asarray(primitive_b1_grad, np.float32),
                jnp.asarray(ref_b1_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )
@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    """Deterministic bf16 tensor of the given shape, uniform in [5, 8)."""
    rng_key = jax.random.PRNGKey(0)
    split_keys = jax.random.split(rng_key, 4)
    return jax.random.uniform(split_keys[0], shape, jnp.bfloat16, 5, 8)
class TestActivationLu:
    """Compare the fused activation_lu primitive against a pure-JAX reference."""

    def ref_func(self, x, activation_type):
        """Return (value, grads) of mean(_jax_act_lu(x)) — the reference path."""

        def _reference_loss(inputs):
            return jnp.mean(_jax_act_lu(inputs, activation_type))

        reference_value_and_grad = jit(value_and_grad(_reference_loss, (0,)))
        return reference_value_and_grad(x)

    def primitive_func(self, inputs):
        """Scalar loss over the fused activation primitive under test."""
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        """Fused activation output and gradient must match the JAX reference."""
        self.activation_type = activation_type
        # Gated activations consume one extra slice along the second-to-last axis.
        x = jnp.repeat(random_inputs, len(activation_type), axis=-2)

        prim_value_and_grad = jit(value_and_grad(self.primitive_func, (0,)))
        prim_out, (prim_grad,) = prim_value_and_grad(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
class TestActivationLuFP8(TestActivationLu):
    """FP8 variant: fused activation + cast/transpose/dbias kernels vs the JAX reference."""

    def prim_func(self, x):
        """Run the FP8 fused activation forward/backward via a custom VJP.

        Returns ``(value, (dx, dx_transposed, dbias, amax_out))``.  The three
        extra arguments to the inner ``_prim_func`` are dummies whose gradient
        slots carry the transposed gradient, dbias, and updated amax out of the
        custom VJP.
        """
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        activation_type = self.activation_type

        @jax.custom_vjp
        def _prim_func(x, _x_t, _dbias, _amax):
            output = _prim_func_fwd(x, _x_t, _dbias, _amax)
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            # FP8 activation, then dequantize so the loss is computed in x.dtype.
            activation_lu_out, _ = tex.act_lu_fp8(
                x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
            )
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = x
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1:  # gated, no bias
                dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
                    g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
                )
                # Gated path has no dbias kernel; emit a placeholder of matching shape.
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else:  # not gated, with bias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
                    tex.dact_lu_dbias_cast_transpose(
                        g,
                        x,
                        amax,
                        scale,
                        scale_inv,
                        FP8Helper.BWD_DTYPE,
                        -1,
                        self.activation_type,
                    )
                )
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
            return ctx

        _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)

        # Placeholders for the dummy arguments; only their gradient slots matter.
        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_axes], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(
            lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
        )
        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        """FP8 fused activation value/grad/transpose/dbias/amax vs the reference."""
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)
        self.activation_type = activation_type

        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        # Move the last two axes to the front — the layout produced by the
        # fused cast-transpose kernels.
        axes = jnp.arange(x.ndim)
        self.transpose_axes = tuple([*axes[-2:]] + [*axes[:-2]])
        # NOTE(review): stray debug print — consider removing.
        print(self.transpose_axes)

        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, (ref_grad,) = self.ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        if "linear" not in activation_type:
            # NOTE(review): `axis=` receives a generator expression here, not a
            # tuple — confirm jnp.sum accepts it, or wrap in tuple(...).
            assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(
            prim_grad_trans,
            jnp.transpose(ref_grad, self.transpose_axes),
            dtype=FP8Helper.BWD_DTYPE,
        )
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    @staticmethod
    def _generate_fp8_meta():
        """Build a default FP8 meta triple: dtypes, zeroed amax histories, unit scales.

        NOTE(review): duplicates TestFP8Dot._generate_fp8_meta — consider a
        shared module-level helper.
        """
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
        """
        JAX native layernorm implementations
        - bias is not None: layernorm
        - bias is None: rmsnorm
        """
        # Normalize in fp32 for accuracy, cast back to the input dtype at the end.
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            # rmsnorm: no mean subtraction.
            mean = 0.0
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            # Stored gamma is an offset from 1.
            scale += 1.0
        if bias is None:
            bias = 0.0
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

    @pytest.mark.parametrize("n, hidden", LN_CASES)
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_layernorm_forward_backward(
        self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        # Either expect the unsupported-combination assertion, or run normally.
        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)

            x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == "layernorm":
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
                beta = None

            def compute_loss(x):
                # Higher precision to compute the loss
                x_ = x.astype(jnp.float32)
                return jnp.mean(jnp.square(x_)).astype(x.dtype)

            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
                x, gamma, beta
            )
            reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
                x, gamma, beta
            )

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
            assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
            if beta is not None:
                assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [True, False])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        # Either expect the unsupported-combination assertion, or run normally.
        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)

            a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)

            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == "layernorm":
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None

            _, amax_list_1, scale_list_1 = TestNorm._generate_fp8_meta()

            def primitive_func(x, y, gamma, beta, amax_list_1, scale_list_1):
                fp8_meta_pkg = FP8MetaPackage(
                    amax_list_1[0],
                    scale_list_1[0],
                    amax_list_1[1],
                    scale_list_1[1],
                    amax_list_1[2],
                    scale_list_1[2],
                )
                primitive_out = layernorm_fp8_dot(
                    x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
                )
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
                x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                return jnp.mean(jnp.dot(x, y))

            value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
                value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
            )

            # A few iterations let the FP8 amax history/scales settle.
            for _ in range(3):
                primitive_out, (
                    primitive_a_grad,
                    primitive_b_grad,
                    primitive_gamma_grad,
                    primitive_beta_grad,
                    amax_list_1,
                    scale_list_1,
                ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
            assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
            if beta is not None:
                assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "input_shape, transpose_axis",
    [
        pytest.param((16, 16), 1, id="(16, 16)-1"),
        pytest.param((256, 128), 1, id="(256, 128)-1"),
        pytest.param((128, 512), 1, id="(128, 512)-1"),
        pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"),
        pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"),
        pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"),
    ],
)
class TestTranspose:
    """Cross-check transpose-family custom calls against the pure-JAX reference.

    Each test evaluates the same operation three ways — the pure-JAX helper,
    the custom call with FFI disabled (NVTE_JAX_WITH_FFI=0), and the custom
    call with FFI enabled — and asserts all three agree.
    """

    @staticmethod
    def _call_with_ffi_flag(flag, func, *args):
        """Invoke ``func(*args)`` with NVTE_JAX_WITH_FFI forced to ``flag``
        ("0" or "1"), then restore the caller's environment.

        The original tests set the variable and never restored it, leaking
        NVTE_JAX_WITH_FFI=1 into every test that ran afterwards.
        """
        prior = os.environ.get("NVTE_JAX_WITH_FFI")
        os.environ["NVTE_JAX_WITH_FFI"] = flag
        try:
            return func(*args)
        finally:
            if prior is None:
                os.environ.pop("NVTE_JAX_WITH_FFI", None)
            else:
                os.environ["NVTE_JAX_WITH_FFI"] = prior

    def test_transpose(self, in_dtype, input_shape, transpose_axis):
        """Plain transpose: pure-JAX reference vs. non-FFI vs. FFI custom call."""
        key = jax.random.PRNGKey(0)
        input_tensor = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis)
        noffi_output = self._call_with_ffi_flag(
            "0", tex.transpose, input_tensor, static_axis_boundary, transpose_axis
        )
        ffi_output = self._call_with_ffi_flag(
            "1", tex.transpose, input_tensor, static_axis_boundary, transpose_axis
        )
        assert_allclose(jax_output, noffi_output)
        assert_allclose(noffi_output, ffi_output)

    @pytest.mark.parametrize(
        "out_dtype",
        [
            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
        ],
    )
    def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
        """Fused FP8 cast + transpose: pure-JAX reference vs. non-FFI vs. FFI."""
        amax = jnp.zeros(1, jnp.float32)
        scale = jnp.ones(1, jnp.float32)
        scale_inv = jnp.ones(1, jnp.float32)
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        # NOTE(review): the pure-JAX helper takes (scale, amax) while tex takes
        # (amax, scale, scale_inv) — the orders differ; verify against the
        # helper signatures if touching these calls.
        jax_output = _jax_cast_transpose(
            x, scale, amax, out_dtype, static_axis_boundary, transpose_axis
        )
        noffi_output = self._call_with_ffi_flag(
            "0",
            tex.cast_transpose,
            x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis,
        )
        ffi_output = self._call_with_ffi_flag(
            "1",
            tex.cast_transpose,
            x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis,
        )
        assert_tree_like_allclose(jax_output, ffi_output)
        assert_tree_like_allclose(noffi_output, ffi_output)

    @pytest.mark.parametrize(
        "out_dtype",
        [
            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
        ],
    )
    def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
        """Fused dbias + FP8 cast + transpose: pure-JAX reference vs. non-FFI vs. FFI."""
        amax = jnp.zeros(1, jnp.float32)
        scale = jnp.ones(1, jnp.float32)
        scale_inv = jnp.ones(1, jnp.float32)
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_dbias_cast_transpose(
            x, amax, scale, out_dtype, static_axis_boundary, transpose_axis
        )
        noffi_output = self._call_with_ffi_flag(
            "0",
            tex.dbias_cast_transpose,
            x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis,
        )
        ffi_output = self._call_with_ffi_flag(
            "1",
            tex.dbias_cast_transpose,
            x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis,
        )
        assert_tree_like_allclose(jax_output, ffi_output)
        assert_tree_like_allclose(noffi_output, ffi_output)


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
    "input_shape",
    [
        pytest.param((256, 128), id="(256, 128)"),
        pytest.param((128, 512, 8), id="(128, 512, 8)"),
    ],
)
@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "out_dtype",
    [
        pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
        pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
    ],
)
def test_quantize(input_shape, in_dtype, out_dtype):
    """FP8 cast (quantize): pure-JAX reference vs. the custom call with FFI
    disabled (NVTE_JAX_WITH_FFI=0) and enabled; all three must agree.

    NVTE_JAX_WITH_FFI is restored afterwards — the original version left it
    permanently set to "1", leaking state into subsequent tests.
    """
    amax = jnp.zeros(1, jnp.float32)
    scale = jnp.ones(1, jnp.float32)
    scale_inv = jnp.ones(1, jnp.float32)
    key = jax.random.PRNGKey(0)
    x = jax.random.uniform(key, input_shape, in_dtype)
    jax_output = _jax_cast_fp8(x, scale, amax, out_dtype)
    prior = os.environ.get("NVTE_JAX_WITH_FFI")
    try:
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.cast_fp8(x, amax, scale, scale_inv, out_dtype)
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.cast_fp8(x, amax, scale, scale_inv, out_dtype)
    finally:
        if prior is None:
            os.environ.pop("NVTE_JAX_WITH_FFI", None)
        else:
            os.environ["NVTE_JAX_WITH_FFI"] = prior
    assert_tree_like_allclose(jax_output, ffi_output)
    assert_tree_like_allclose(noffi_output, ffi_output)