# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator

from utils import (
    assert_allclose,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp

from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import (
    _jax_layernorm,
    _jax_rmsnorm,
    is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
    NoScaleTensor,
    ScaledTensor,
    ScaledTensor1x,
    ScaledTensor2x,
    GroupedScaledTensor1x,
    ScalingMode,
    QuantizerFactory,
    QuantizeLayout,
    noop_quantizer_set,
    QuantizeMetaSet,
    QuantizeMeta,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
58
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
59
LN_CASES = [(256, 128), (128, 256)]
60
DTYPES = [jnp.bfloat16, jnp.float32]
61

62
63
64
65
66
67
68
69
70
71
72
# TODO(Phuong): remove unneccessary pytest skips
is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.DELAYED_TENSOR_SCALING
)
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.MXFP8_1D_SCALING
)
is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.NVFP4_1D_SCALING
)

73
""" Find supported scaling modes"""
74
75
76
77
supported_scaling_modes = helper.get_supported_scaling_modes()
non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling]
supported_recipes = helper.get_supported_quantization_recipes()
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]
78
79
80
81
82
83


def is_shape_supported_by_mxfp8(input_shape):
    """Return True if ``input_shape`` can be quantized with MXFP8 1D (block) scaling.

    Accepts either a raw shape tuple or a ``pytest.param`` wrapping one (as
    produced by the parametrize wrappers in this file).
    """
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            # Unwrap the shape tuple from a pytest.param container.
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except Exception:
        # get_scale_shape_2x raises if the shape is not supported.  Catch
        # Exception (not bare `except:`) so KeyboardInterrupt/SystemExit
        # still propagate.
        return False
def assert_bitwise_scaled_tensors(
    a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
    """Assert two quantized tensors match.

    With ``precise_comparison=True`` the quantized payloads must be bitwise
    equal; otherwise a relaxed, scaling-mode-dependent comparison is used.
    ``ScaledTensor2x`` inputs are compared half-by-half (rowwise and colwise).
    """
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
        if not precise_comparison and not a.scaling_mode.is_nvfp4_scaling:
            # Relaxed non-NVFP4 path: only the dequantized values need to be
            # close.  NVFP4 has its own relaxed path below that still checks
            # scales and amax.
            assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
            return

        # Metadata must agree before payloads are compared.
        assert a.scaling_mode == b.scaling_mode
        assert a.scale_inv.dtype == b.scale_inv.dtype
        assert a.data_layout == b.data_layout

        if a.scaling_mode.is_tensor_scaling():
            # Assert in dq_dtype as some unfused codepaths have an intermediate cast
            # to an input dtype which reduces precision compared to everything in fp32
            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Compare MXFP8 scales as uint8
            assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
        elif a.scaling_mode.is_nvfp4_scaling:
            assert_allclose(a.amax, b.amax)
            assert_allclose(a.scale_inv, b.scale_inv)
            if not precise_comparison:
                # Relaxed NVFP4 path: allow a small fraction of mismatched
                # FP4 elements rather than exact payload equality.
                mismatch = a.data != b.data
                mismatch_fraction = jnp.mean(mismatch.astype(jnp.float32))
                assert (
                    mismatch_fraction < 0.05
                ), f"Mismatch fraction {mismatch_fraction} is too high"
                return
        else:
            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")

        # Precise path: quantized payloads must match exactly.
        assert_allclose(a.data, b.data)

    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
        assert_bitwise_scaled_tensors(
            a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison
        )
        assert_bitwise_scaled_tensors(
            a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison
        )
    else:
        pytest.fail("Unsupported input types")
def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
    """Assert that dequantizing ``a`` recovers the high-precision reference ``b``."""
    if isinstance(a, ScaledTensor1x):
        reference = b
        if a.data_layout == "T":
            # Transposed layout: rotate the reference axes so the trailing
            # (flattened) axes come first, matching the stored data order.
            pivot = a.data.ndim - a.flatten_axis
            permutation = (*range(pivot, b.ndim), *range(pivot))
            reference = jnp.transpose(b, permutation)
        assert_allclose(a.dequantize(), reference, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        # Both halves must dequantize to the same reference values.
        assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_scaled_tensor(a.colwise_tensor, b)
    else:
        pytest.fail("a must be a ScaledTensor object")
def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
    """Assert a grouped quantized tensor dequantizes back to reference ``b``.

    ``b`` is a single stacked array whose leading axis is partitioned by
    ``a.group_sizes``; each group is compared independently.
    """
    if isinstance(a, GroupedScaledTensor1x):
        assert a.group_sizes.sum() == b.shape[0]
        # Split the stacked reference into per-group slices.
        b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
        dq_a = a.dequantize()
        for dq_a_i, b_i in zip(dq_a, b):
            if len(dq_a_i) == 0:
                # Empty group: nothing to compare.
                continue
            if a.data_layout == "T":
                # Transposed layout: permute the reference slice to match the
                # stored data order before comparing.
                data_ndim = len(a.original_shape)
                flatten_axis = a.flatten_axis
                if b_i.shape[0] == 1:
                    # Keep the singleton group axis in front, rotate the rest.
                    b_i = jnp.transpose(
                        b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
                    )
                else:
                    b_i = jnp.transpose(
                        b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
                    )
            dq_a_i = dq_a_i.reshape(b_i.shape)
            assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
        assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
        assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
    else:
        pytest.fail("a must be a GroupedScaledTensor object")
ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [
    ("gelu",),
    ("gelu", "linear"),
    ("silu",),
    ("silu", "linear"),
    ("relu",),
    ("relu", "linear"),
    ("quick_gelu",),
    ("quick_gelu", "linear"),
    ("squared_relu",),
    ("squared_relu", "linear"),
193
    ("clamped_silu", "clamped_linear"),
194
]
195

196
197
198
199
200
201
202
ACTIVATION_TYPES = {
    "L0": [
        ("gelu",),
        ("gelu", "linear"),
    ],
    "L2": ALL_ACTIVATION_TYPES,
}
203
204


205
class TestActivation:
    """Compare transformer_engine.jax activation kernels against the pure-JAX
    reference implementation (_jax_act_lu): gradients, FP8 tensor-scaling and
    MXFP8 block-scaling forward paths."""

    def ref_act(self, x, activation_type, act_params):
        """Reference forward pass: pure-JAX act_lu, returning unquantized data."""
        return _jax_act_lu(x, activation_type, act_params=act_params).data

    def value_n_grad_ref_func(self, x, activation_type, act_params):
        """Value and gradient (w.r.t. x) of the mean of the reference activation."""
        jitted_reference = jit(
            value_and_grad(
                lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
            )
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer, act_params):
        """Scalar loss (mean) of the TE activation primitive, for value_and_grad."""
        out = activation(
            inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
        )
        return jnp.mean(out)

    def _make_act_params(self, activation_type):
        """Build ActivationParams for activations that take extra arguments.

        Only the ("clamped_silu", "clamped_linear") pair is parameterized
        (limit=0.75, alpha=1.702); every other activation type uses None.
        Replaces the identical conditional previously repeated in each test.
        """
        if activation_type == ("clamped_silu", "clamped_linear"):
            return tex.activation.ActivationParams.create(
                activation_type=activation_type, limit=0.75, alpha=1.702
            )
        return None

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper(
        "activation_type",
        (
            ALL_ACTIVATION_TYPES  # Test all activation types for this test to ensure all are functional, then just test a subset for the other tests to verify other functionality
        ),
    )
    def test_act_grad(self, shape, activation_type):
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        # Gated activations consume one input slice per branch: insert and
        # tile a branch axis at position -2.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
        )

        act_params = self._make_act_params(activation_type)
        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        # random_inputs is a fixture presumably fed by the "shape"
        # parametrization (defined in conftest — confirm).
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)),
            static_argnums=(1, 3),
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )
        act_params = self._make_act_params(activation_type)

        prim_out, (prim_grad,) = value_n_grad_primitive_func(
            x, activation_type, quantizer, act_params
        )
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

        # Compare with tolerances appropriate to the quantized output dtype.
        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    # NOTE(review): this tensor-scaling test is guarded by the MXFP8
    # capability flag; that looks like a copy-paste from the block-scaling
    # test below — confirm whether it should be is_fp8_supported.
    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        # Two independent quantizers so the TE and reference paths do not
        # share internal quantization state.
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )
        act_params = self._make_act_params(activation_type)

        te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        # The parametrized shape already carries a singleton branch axis at
        # -2, so no expand_dims here — just tile it per activation branch.
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )
        act_params = self._make_act_params(activation_type)

        output = tex.act_lu(x, activation_type, quantizer, act_params)
        ref_out = self.ref_act(x, activation_type, act_params)

        assert_dequantized_scaled_tensor(output, ref_out)
NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
372

373

374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """
389

390
391
392
393
394
395
396
397
398
399
400
401
402
    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
403
404
            # This is a no-op for non-quantized data
            ln_out = ln_out.dequantize()
405
            return ln_out
406

407
408
409
410
411
412
413
414
415
416
417
418
419
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
        x = x.astype(inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
        else:
            beta = None
420

421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
        jitted_reference = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    reference_func(
                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                    )
                ),
                (0, 1, 2),
            )
        )
        jitted_primitive = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
                ),
                (0, 1, 2),
437
            )
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
        )

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)
453

454
455
456
457
458
459
460
461
462
463
    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )
464

465
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
466
467
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
468
469
470
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
471
472
473
474
475
476
477
478
479
480
481
482
483
484
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
485
486
487
488
489
490
491
492
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
493
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
509
        q_layout,
510
    ):
511
        key = jax.random.PRNGKey(0)
512
        subkeys = jax.random.split(key, 3)
513

514
515
516
517
518
        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
        x = jnp.asarray(x, inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
519

520
        quantizer, ref_quantizer = QuantizerFactory.create(
521
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
522
523
524
525
526
527
        )
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
528
            )
529
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
530
531
532
533
534
535
                x,
                gamma,
                beta,
                zero_centered_gamma,
                epsilon,
                quantizer=ref_quantizer,
536
            )
537
538
539
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
540
            )
541
            ref_out, ref_rsigma = _jax_rmsnorm(
542
543
544
545
546
                x,
                gamma,
                zero_centered_gamma,
                epsilon,
                quantizer=ref_quantizer,
547
            )
548
            ref_mu = None
549

550
551
552
        precise_comparison = True

        if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
553
554
            # Reduce precision of test as we don't use fused norm below this version CuDNN for MXFP8 and instead
            # do an unfused norm and quantize with an intermediate cast into in_dtype which can reduce precision
555
556
557
558
559
560
561
562
563
564
565
            precise_comparison = False
        elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
            # Larger tolerances as our JAX implementation _jax_*norm uses the compute dtype float32
            # for zero-centered gamma always
            precise_comparison = False
        elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
            # Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
            # and writes intermediate results into the input dtype, which will slightly reduce precision
            # if the input dtype is not float32
            precise_comparison = False

566
        assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)
567

568
569
570
        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)
571

572
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
573
574
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
575
576
577
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
578
579
580
581
582
583
584
585
586
587
588
589
590
591
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
592
593
594
595
596
597
598
599
600
601
602
603
    ):
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
604
            scaling_mode=scaling_mode,
605
            q_layout=q_layout,
606
        )
607

608
    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
609
610
611
612
613
614
    @pytest.mark.parametrize(
        "out_dtype",
        [
            jnp.float8_e4m3fn,
        ],
    )
615
616
617
618
619
620
621
622
623
624
625
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
626
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
627
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
628
        )
629
630


631
# FP8 quantization output dtypes by test level.
QUANTIZE_OUTPUT_FP8_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
# Same sets extended with the FP4 dtype used by NVFP4 scaling.
QUANTIZE_OUTPUT_DTYPES = {
    test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn]
    for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
# Compatible (q_dtype, scaling_mode) pairs per test level.
# NOTE(review): zip() pairs the i-th dtype with the i-th supported scaling
# mode (truncating to the shorter list) rather than taking the full cross
# product — confirm this pairing is intentional.
QUANTIZE_QDTYPE_AND_SCALING_MODES = {
    test_level: [
        (q_dtype, scaling_mode)
        for q_dtype, scaling_mode in zip(
            QUANTIZE_OUTPUT_FP8_DTYPES[test_level], supported_scaling_modes
        )
        if q_dtype in scaling_mode.get_compatible_q_dtypes()
    ]
    for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
# (input_shape, flatten_axis) cases for the quantize tests.  The negative
# flatten_axis marks where the tensor is logically flattened to 2D before
# quantization (assumption inferred from usage elsewhere — confirm).
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    ((32, 64), -1),
    ((2, 64, 32), -1),
    ((64, 2, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    ((64, 32, 32, 256), -1),
    ((8192, 2, 4096), -2),
]

# Per-test-level subsets of the shape/axis cases above.
QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((64, 2, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}

# High-precision input dtypes for quantization, by test level.
QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
676
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
677
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn])
678
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
679
680
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
681
682
683
684
685
686
    "q_layout",
    [
        QuantizeLayout.ROWWISE,
        QuantizeLayout.COLWISE,
        QuantizeLayout.ROWWISE_COLWISE,
    ],
687
688
689
690
691
692
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

693
694
    def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
        """Skip unsupported dtypes for given scaling mode. For example, NVFP4 only supports the float4_e2m1 dtype not float8 dtypes."""
695
696
697
698
        if q_dtype not in scaling_mode.get_compatible_q_dtypes():
            pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
            return

699
    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
700
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
701

702
        key = jax.random.PRNGKey(0)
703

704
705
706
707
        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
708
            q_layout=q_layout,
709
        )
710

711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
        if scaling_mode.is_nvfp4_scaling:
            if in_dtype != jnp.bfloat16:
                pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently")
                return
            q_func = _jax_quantize
            # For NVFP4 scaling, the maximum possible error for a single value can be high between the dequantized and original tensors. To ensure quantization and dequantization is operating correctly without requiring a very high tolerance for all values, we instead test that quantizing the dequantized tensor is bitwise identical to the original quantized tensor.
            x = jax.random.uniform(key, input_shape, in_dtype) * 10
            q1 = q_func(x, quantizer=quantizer, flatten_axis=flatten_axis)

            dq_rowwise = None
            dq_colwise = None
            if isinstance(q1, ScaledTensor1x):
                dq = q1.dequantize()
                if q1.is_colwise:
                    dq_colwise = dq
                else:
                    dq_rowwise = dq
            elif isinstance(q1, ScaledTensor2x):
                dq_rowwise = q1.rowwise_tensor.dequantize()
                dq_colwise = q1.colwise_tensor.dequantize()
            else:
                raise ValueError(f"Unsupported output type {type(q1)}")

            # We only compare Q-DQ for the same quantization layout. If we for example QDQ rowwise, then re-quantize colwise, the error will be larger and may not be bitwise identical to the original colwise quantization.
            if dq_rowwise is not None:
                assert (
                    dq_rowwise.shape == x.shape
                ), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}"
                q2_rowwise = q_func(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis)
                q2_rowwise = (
                    q2_rowwise
                    if isinstance(q2_rowwise, ScaledTensor1x)
                    else q2_rowwise.rowwise_tensor
                )
                q1_rowwise = q1 if isinstance(q1, ScaledTensor1x) else q1.rowwise_tensor
                assert_bitwise_scaled_tensors(q1_rowwise, q2_rowwise)

            if dq_colwise is not None:
                # Since this is for NVFP4, we are assuming colwise has T layout and we do a transpose here to get back to original shape
                flatten_axis = flatten_axis + len(input_shape) if flatten_axis < 0 else flatten_axis
                colwise_flatten_axis = len(input_shape) - flatten_axis
                dq_colwise = jnp.transpose(
                    dq_colwise,
                    (*range(colwise_flatten_axis, dq_colwise.ndim), *range(colwise_flatten_axis)),
                )
                assert (
                    dq_colwise.shape == x.shape
                ), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}"
                q2_colwise = q_func(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis)
                q2_colwise = (
                    q2_colwise
                    if isinstance(q2_colwise, ScaledTensor1x)
                    else q2_colwise.colwise_tensor
                )
                q1_colwise = q1 if isinstance(q1, ScaledTensor1x) else q1.colwise_tensor
                assert_bitwise_scaled_tensors(q1_colwise, q2_colwise)

            assert (
                dq_rowwise is not None or dq_colwise is not None
            ), "At least one of rowwise or colwise dq must be not None"
            return

773
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
774
775
776
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)

777
            scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
778
779
            assert_dequantized_scaled_tensor(scaled_tensor, x)

780
    def _should_use_precise_comparison(
781
        self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
782
783
784
785
786
787
788
    ):
        if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
            # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation
            return False

        return True

789
790
791
    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
792
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
793
794
795
796
797

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
798
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
799
        )
800

801
        jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
802

803
        te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
804
805
806
807
808

        assert_bitwise_scaled_tensors(
            te_output,
            jax_output,
            precise_comparison=self._should_use_precise_comparison(
809
                in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
810
811
812
813
814
815
            ),
        )

    def test_quantize_bitwise_jitted(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
816
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
817
818
819
820
821
822
823
824
825
826
827
828
829

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
        te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))

        jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)

830
        te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
831
832
833
834
835

        assert_bitwise_scaled_tensors(
            te_output,
            jax_output,
            precise_comparison=self._should_use_precise_comparison(
836
                in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
            ),
        )


@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper(
    "scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling]
)
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestStochasticRounding:
    """Tests for stochastic rounding (SR) during NVFP4 quantization."""

    def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]:
        """Dequantizes a ScaledTensor back to its original jnp.ndarray form. This always returns an array of jnp.ndarrays, for ScaledTensor2x there will be two tensors, for ScaledTensor1x there will be one tensor."""
        if isinstance(scaled_tensor, ScaledTensor1x):
            dq = scaled_tensor.dequantize()
            if scaled_tensor.data_layout == "T":
                # Undo the transposed (T) layout so the result matches the input shape.
                dq = jnp.transpose(
                    dq,
                    (
                        *range(scaled_tensor.flatten_axis, dq.ndim),
                        *range(scaled_tensor.flatten_axis),
                    ),
                )
            return [dq]
        elif isinstance(scaled_tensor, ScaledTensor2x):
            [rowwise_dq] = self._dequantize(scaled_tensor.rowwise_tensor)
            [colwise_dq] = self._dequantize(scaled_tensor.colwise_tensor)
            return [rowwise_dq, colwise_dq]
        raise ValueError(
            "Unsupported ScaledTensor type, expected ScaledTensor but received"
            f" {type(scaled_tensor)}"
        )

    def _sample_sr_qdq(
        self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
    ) -> list[jnp.ndarray]:
        """Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors."""
        dq_tensors = []

        key = jax.random.PRNGKey(0)

        for i in range(num_samples):
            # Fresh RNG state per sample so each quantization rounds differently.
            iter_key = jax.random.fold_in(key, i)
            sr_rng_state = jax.random.randint(
                iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
            )
            quantizer = QuantizerFactory.create(
                q_dtype=q_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                stochastic_rounding_rng_state=sr_rng_state,
            )

            q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
            iter_dq = self._dequantize(q_output)
            dq_tensors.extend(iter_dq)

            avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0)
            assert avg_sr_tensor.shape == inputs.shape, (
                f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
                f" {inputs.shape}"
            )
            # (A dead per-iteration MAE computation was removed here; the MAE
            # comparison against round-nearest happens in _test_sr.)

        dq_var = jnp.var(jnp.stack(dq_tensors))
        assert (
            dq_var > 0
        ), "Variance of dequantized tensors is zero, stochastic rounding may not be working"

        return dq_tensors

    def _round_nearest(
        self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
    ) -> jnp.ndarray:
        """Quantizes and dequantizes the input tensor with round nearest quantization."""
        quantizer = QuantizerFactory.create(
            q_dtype=q_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            stochastic_rounding_rng_state=None,
        )
        q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
        rn_dq = self._dequantize(q_output)[0]
        return rn_dq

    def _test_sr(
        self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
    ) -> float:
        """Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples."""
        dq_tensors = self._sample_sr_qdq(
            num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
        )
        avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0)
        assert avg_sr_tensor.shape == inputs.shape, (
            f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
            f" {inputs.shape}"
        )

        round_nearest_tensor = self._round_nearest(
            q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
        )

        sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))
        rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs))

        assert sr_mae < rn_mae, (
            f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than"
            f" round nearest ({rn_mae})"
        )

        return sr_mae

    def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""

        key = jax.random.PRNGKey(0)
        inputs = jax.random.uniform(key, input_shape, in_dtype)

        NUM_SAMPLES = 10

        te_mean_error = self._test_sr(
            NUM_SAMPLES, tex.quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
        )
        jax_mean_error = self._test_sr(
            NUM_SAMPLES, _jax_quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
        )

        assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)


@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper(
    "scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
)
class TestRandomizedHadamardTransform:
    """Tests for NVFP4 quantization with the randomized Hadamard transform (RHT) enabled."""

    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
    )
    @pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
    def test_rht_quantize_bitwise_jitted(
        self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
    ):
        """Checks the jitted TE RHT quantization is bitwise identical to the JAX reference."""
        key = jax.random.PRNGKey(0)
        inputs = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            q_dtype=q_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            use_rht=True,
        )

        jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
        te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))

        jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)

        te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        """Reference GEMM: transpose operands marked "T" in the layout, then jnp.dot."""
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        """Generates bf16 GEMM operands (scaled by 1/sqrt of the reduced dim) and the contracting dims for the given layout."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    # We do not test NN and TT layouts here as they do not have both inputs using RHT due to RHT only supporting the colwise layout currently
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [True, False])
    def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm):
        """Checks an NVFP4 GEMM with RHT-enabled quantizers against a jnp.dot reference."""
        key = jax.random.PRNGKey(0)

        lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        lhs_quantizer = QuantizerFactory.create(
            scaling_mode=lhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
            use_rht=True,
        )
        rhs_quantizer = QuantizerFactory.create(
            scaling_mode=rhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
            use_rht=True,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=lhs_quantizer,
                rhs_quantizer=rhs_quantizer,
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
    """Tests for grouped quantization, both with explicit group sizes and with a uniform leading group dimension."""

    def test_grouped_qdq(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
    ):
        """Round-trips tex.grouped_quantize and checks the dequantized result against the input."""
        n_groups, m, n = input_shape
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # *32 so that the input shapes works for MXFP8
        input_shape = (m * 32, n)

        if with_group_sizes:
            # Random partition of the m rows into n_groups contiguous groups.
            group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
            group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
            group_sizes = jnp.diff(group_sizes)
            assert group_sizes.sum() == m
            assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 row
            group_sizes = group_sizes * 32
        else:
            group_sizes = None
            input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])

        if flatten_axis == -2:
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]

        x = jax.random.uniform(subkeys[1], input_shape, in_dtype)

        grouped_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
            n_groups=n_groups,
        )
        scaled_tensor = tex.grouped_quantize(
            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
        )

        assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)


@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
    """Tests for quantization fused with dbias / dact computations."""

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
    @pytest_parametrize_wrapper("out_dtype,scaling_mode", QUANTIZE_QDTYPE_AND_SCALING_MODES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """Checks the fused quantize+dbias TE primitive against the JAX reference."""
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        te_output, te_dbias = jit(
            lambda input: tex.quantize_dbias(
                input, quantizer=te_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        jax_output, jax_dbias = jit(
            lambda input: _jax_quantize_dbias(
                input, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        assert_bitwise_scaled_tensors(te_output, jax_output)

        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
        """Shared driver: compares fused quantize+dact(+dbias) between TE and the JAX reference."""

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        # Replicate the input across a new axis — one slice per activation in a gated pair.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)

        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)

        if is_casted_output:
            # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation
            precise_comparison = not (
                in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
            )
            assert_bitwise_scaled_tensors(
                te_output, jax_output, precise_comparison=precise_comparison
            )
        else:
            assert isinstance(te_output, NoScaleTensor)
            assert isinstance(jax_output, NoScaleTensor)
            assert_allclose(te_output.data, jax_output.data)

        if is_dbias:
            precise_comparison = not (
                # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
                (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
                # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
                or (
                    activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
                    and in_dtype == jnp.bfloat16
                    and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
                )
            )
            assert_allclose(
                te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
            )

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        """dact(+dbias) with no quantization (NO_SCALING)."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_quantize_dact_dbias_tensor_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
    ):
        """dact(+dbias) fused with delayed/current tensor scaling quantization."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        """dact(+dbias) fused with MXFP8 1D block-scaled quantization."""
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )


# FP8 GEMM operand dtype pairs exercised by the tests.
# NOTE(review): (e5m2, e5m2) is intentionally absent — presumably an
# unsupported operand combination for the FP8 GEMM; confirm against backend docs.
valid_fp8_gemm_operand_types = [
    (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
    (jnp.float8_e5m2, jnp.float8_e4m3fn),
    (jnp.float8_e4m3fn, jnp.float8_e5m2),
]

# (lhs, rhs) NVFP4 scaling-mode pairs supported by the GEMM tests.
supported_nvfp4_scaling_mode_pairs = [
    (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_1D_SCALING),
    (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING),
]


class TestDense:
    """Tests for dense/GEMM primitives against jnp.dot references."""

    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        """Reference GEMM: transpose operands marked "T" in the layout, then jnp.dot."""
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)
1312

1313
    def _generate_gemm_input(self, m, n, k, data_layout):
1314
1315
1316
1317
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
1318
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
1319
1320
1321
1322
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
1323
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
1324
1325
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
1326
1327
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
1328
1329
1330
1331
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

1332
1333
1334
1335
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
1336

Alp Dener's avatar
Alp Dener committed
1337
        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
1338
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
1339

1340
        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
1341

1342
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
1343
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
Alp Dener's avatar
Alp Dener committed
1344
    @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
1345
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
1346
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
Alp Dener's avatar
Alp Dener committed
1347
1348
1349
1350
1351
1352
1353
1354
1355
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
        if (
            not with_jax_gemm
            and scaling_mode.is_1d_block_scaling()
            and jnp.float8_e5m2 in (x_qtype, w_qtype)
        ):
            pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")

1356
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
1357
        quantizer_set = QuantizerFactory.create_set(
Alp Dener's avatar
Alp Dener committed
1358
1359
1360
1361
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2,
            is_2x2x=False,
1362
        )
Alp Dener's avatar
Alp Dener committed
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=(
                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
                rhs_quantizer=(
                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
            )
1375
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
1376

Alp Dener's avatar
Alp Dener committed
1377
        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
1378

1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
    # TODO(Phuong): add bitwise test
    @pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("scaling_mode_pair", supported_nvfp4_scaling_mode_pairs)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [True, False])
    def test_gemm_nvfp4(self, m, n, k, scaling_mode_pair, data_layout, with_jax_gemm):
        x_uses_rht = scaling_mode_pair[0] == ScalingMode.NVFP4_1D_SCALING and data_layout[0] == "T"
        w_uses_rht = scaling_mode_pair[1] == ScalingMode.NVFP4_1D_SCALING and data_layout[1] == "N"
        if x_uses_rht != w_uses_rht:
            # TODO(jberchtold): Ideally avoid a skip here and rewrite test setup to ensure both or neither use RHT
            pytest.skip("RHT must be used for both or neither operand, skipping")

        lhs_scaling_mode, rhs_scaling_mode = scaling_mode_pair
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        lhs_quantizer = QuantizerFactory.create(
            scaling_mode=lhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
        )
        rhs_quantizer = QuantizerFactory.create(
            scaling_mode=rhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=lhs_quantizer,
                rhs_quantizer=rhs_quantizer,
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)

1413
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
1414
    def test_dense_grad_bf16(self, m, n, k):
1415
1416
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
1417

1418
1419
1420
        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)
1421

1422
1423
        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))
1424

1425
        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
1426

1427
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
1428

1429
1430
        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
1431
        )
1432
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)
1433
1434
1435
1436

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
1437

1438
1439
    @pytest_parametrize_wrapper("m,n,k", [(64, 128, 128)])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
Alp Dener's avatar
Alp Dener committed
1440
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
1441
    def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm):
1442
1443
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
1444
1445
1446
1447
1448
1449
1450
1451
1452

        key = jax.random.PRNGKey(1)
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)
1453

1454
        def ref_func(x, w, bias, data_layout):
1455
            return jnp.mean(
1456
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
1457
            )
1458

1459
1460
        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
1461

1462
1463
1464
1465
1466
1467
        quantizer_set = QuantizerFactory.create_set(
            fp8_recipe=recipe,
            quantize_meta_set=QuantizeMetaSet(
                x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
            ),
        )
1468

1469
        n_iterations = 3 if recipe.delayed() else 1
Alp Dener's avatar
Alp Dener committed
1470
1471
1472
1473
1474
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                    value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
                )
1475

1476
1477
1478
        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )
1479

1480
1481
1482
1483
        assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype)
1484
1485


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    """Return a bfloat16 tensor of the given shape with uniform values in [5, 8)."""
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
    """Reference normalization using the pure-JAX helpers.

    Dispatches to `_jax_rmsnorm` (ignores `beta`) or `_jax_layernorm` based on
    `norm_type`, then returns the dequantized normalized output.
    """
    if norm_type == "rmsnorm":
        ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    ln_out = ln_out.dequantize()
    return ln_out


class TestFusedDense:
    """Tests for fused layernorm+dense and layernorm MLP VJP rules."""

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm):
        """
        Test layernorm_dense VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)

        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)

        quantizer_set = QuantizerFactory.create_set(
            fp8_recipe=recipe,
            quantize_meta_set=QuantizeMetaSet(
                x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
            ),
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        # Delayed-scaling recipes need a few iterations to populate amax history.
        n_iterations = 3 if recipe.delayed() else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_w_grad,
                    prim_gamma_grad,
                    prim_beta_grad,
                ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, recipe, norm_type, use_bias, with_jax_gemm
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            fp8_recipe=recipe,
            quantize_meta_set=QuantizeMetaSet(
                x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
            ),
        )

        if norm_type == "layernorm":
            # NOTE(review): subkeys[3] is also used for bias_1 above, so beta and
            # bias_1 draw from the same key — presumably benign here; confirm.
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            # norm -> dense 1 (+bias) -> activation -> dense 2 (+bias)
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type).data
            linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        # Delayed-scaling recipes need a few iterations to populate amax history.
        n_iterations = 3 if recipe.delayed() else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_gamma_grad,
                    prim_kernel_1_grad,
                    prim_kernel_2_grad,
                    prim_bias_1_grad,
                    prim_bias_2_grad,
                ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        fwd_dtype = quantizer_sets[0].x.q_dtype
        bwd_dtype = quantizer_sets[0].dgrad.q_dtype
        assert_allclose(prim_out, ref_out, dtype=fwd_dtype)
        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=bwd_dtype)
        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=bwd_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=bwd_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=bwd_dtype)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype)
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype)


# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]

GROUPED_DENSE_INPUT_SHAPES = [
    # (n_groups, m, n, k), the actual m will be multiplied by 32
    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
    (8, 64, 32, 128),
    (8, 64, 128, 256),
]


@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
    """Tests for grouped GEMM / grouped dense forward and VJP."""

    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
        """Reference grouped dense: split lhs by group_sizes and run one dot per group.

        Returns a list with one (squeezed) output array per group. A None bias is
        treated as zeros of shape (n_groups, n).
        """
        lhs_contract_dim, _ = contracting_dims
        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
        if bias is None:
            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
        else:
            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
        bias = jnp.split(bias, bias.shape[0], axis=0)
        ref_out = []
        dim_num = (contracting_dims, ((), ()))
        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
            out_i = jax.lax.dot_general(
                lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
            ) + jnp.expand_dims(bias_i, axis=0)
            ref_out.append(jnp.squeeze(out_i))
        return ref_out

    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
        """Generate (lhs, rhs, group_sizes, contracting_dims, bias) for grouped dense.

        group_sizes are random, sum to m, include one deliberately empty group,
        and are scaled by 32 (as is m) so shapes satisfy MXFP8 requirements.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        n_groups, m, n, k = input_shape

        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
        group_sizes = jnp.diff(group_sizes)
        # Make one empty input lhs to test empty GEMM handling
        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
        group_sizes = group_sizes.at[1].set(0)
        assert group_sizes.sum() == m

        # *32 to make sure that input shape works for MXFP8
        group_sizes = group_sizes * 32
        m = m * 32

        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
        bias_shape = (n_groups, n)

        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None

        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return lhs, rhs, group_sizes, contracting_dims, bias

    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
        """Split the flat grouped output by group_sizes and compare per group."""
        assert out.dtype == ref_list[0].dtype
        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
        for i in range(len(ref_list)):
            assert_allclose(out_list[i], ref_list[i], dtype=dtype)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
        """Grouped GEMM in BF16/FP16 matches the per-group reference."""
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            dtype, input_shape, layout
        )
        num_gemms = input_shape[0]
        _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
            group_sizes,
            num_gemms=num_gemms,
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        # jitting grouped_gemm
        prim_out = jax.jit(
            tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
        )(
            lhs,
            rhs,
            group_sizes,
            contracting_dims,
            use_async_d2h_group_sizes=True,
        )

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
        """Grouped GEMM with mixed FP8 operand dtypes matches the reference."""
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=False,
            n_groups=input_shape[0],
        )

        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
        # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
        quantizer_set.kernel.q_dtype = bwd_dtype
        for quantizer in quantizer_set.kernel.quantizers:
            quantizer.q_dtype = bwd_dtype

        out_dtype = jnp.bfloat16
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            out_dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        # Use the looser E5M2 tolerance if either operand is E5M2.
        allclose_dtype = jnp.float8_e4m3fn
        if jnp.float8_e5m2 in fwd_bwd_dtype:
            allclose_dtype = jnp.float8_e5m2

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)

    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
        """Scalar reference loss: sum of all group outputs, normalized by sqrt(x.size)."""
        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
        # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
        # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to
        # normalize the output and prevent the gradient from being too large for FP8.
        out_sum_list = [jnp.sum(out) for out in out_list]
        return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)

    def _primitive_sum_grouped_dense(
        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
    ):
        """Scalar primitive loss matching `_ref_sum_grouped_dense`."""
        out = grouped_dense(
            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
        )
        return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
        """grouped_dense value and gradients in BF16/FP16 match the reference."""
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        # jitting the grouped_dense
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x, kernel, bias, group_sizes, contracting_dims
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize(
        "fwd_bwd_dtype",
        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
    )
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
        """grouped_dense value and gradients under FP8 quantization match the reference."""
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        dtype = jnp.bfloat16
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=True,
            n_groups=group_sizes.size,
        )
        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))

        # jitting the grouped_dense
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x,
            kernel,
            bias,
            group_sizes,
            contracting_dims,
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
        # dbias is accumulated in high precision, so compare at the input dtype.
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)


class TestDebugInspectFFI:
    """Smoke test for the experimental debug inspect (array dump) FFI."""

    @pytest_parametrize_wrapper("shape", [(256, 128)])
    @pytest_parametrize_wrapper(
        "dtype",
        [
            jnp.float32,
            jnp.bfloat16,
            jnp.float16,
            # Note: fp4 currently doesn't work
            # jnp.float4_e2m1fn
        ]
        + ([jnp.float8_e4m3fn, jnp.float8_e5m2] if is_fp8_supported else []),
    )
    def test_debug_inspect_ffi(self, shape, dtype):
        """inspect_array inside jit dumps the intermediate value without changing it."""
        from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump

        def f(x):
            x = x + 1
            # Dump the intermediate (x + 1); inspect_array passes the value through.
            x = inspect_array(x, "my_array")
            x = x + 1
            return x

        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        x = x.astype(dtype)
        _ = jax.jit(f)(x)

        # The dump on disk should hold x + 1, the value seen by inspect_array.
        # NOTE(review): the tag above is "my_array" but the file read is
        # "my_tensor_gpu0.bin" — confirm this matches the dump naming convention.
        expected = x + 1
        actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)

        assert_allclose(actual, expected, dtype=dtype)