# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator

from utils import (
    assert_allclose,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp

from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import (
    _jax_layernorm,
    _jax_rmsnorm,
    is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
    NoScaleTensor,
    ScaledTensor,
    ScaledTensor1x,
    ScaledTensor2x,
    GroupedScaledTensor1x,
    ScalingMode,
    QuantizerFactory,
    QuantizeLayout,
    noop_quantizer_set,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.common import recipe

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]

# TODO(Phuong): remove unnecessary pytest skips
is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.DELAYED_TENSOR_SCALING
)
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.MXFP8_1D_SCALING
)
is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.NVFP4_1D_SCALING
)

# Find supported scaling modes
supported_scaling_modes = helper.get_supported_scaling_modes()
non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling]
supported_recipes = helper.get_supported_quantization_recipes()
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]


def is_shape_supported_by_mxfp8(input_shape):
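    """Return True if MXFP8 1D scaling defines valid scale shapes for the given input shape.

    Shapes wrapped in pytest.param are unwrapped before checking.
    """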
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except Exception:
        # get_scale_shape_2x raises an exception if the shape is not supported
        return False


def assert_bitwise_scaled_tensors(
    a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
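    """Assert that two quantized tensors match bitwise.

    ScaledTensor2x inputs are compared recursively through their row-wise and column-wise halves.
    With precise_comparison=False, tensor-scaled inputs are only compared after dequantization and
    NVFP4 inputs may tolerate a small fraction (< 5%) of mismatched data elements.
    """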
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
        if not precise_comparison and not a.scaling_mode.is_nvfp4_scaling:
            assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
            return

        assert a.scaling_mode == b.scaling_mode
        assert a.scale_inv.dtype == b.scale_inv.dtype
        assert a.data_layout == b.data_layout
        if a.scaling_mode.is_tensor_scaling():
            # Assert in dq_dtype as some unfused codepaths have an intermediate cast
            # to an input dtype which reduces precision compared to everything in fp32
            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Compare MXFP8 scales as uint8
            assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
        elif a.scaling_mode.is_nvfp4_scaling:
            assert_allclose(a.amax, b.amax)
            assert_allclose(a.scale_inv, b.scale_inv)
            if not precise_comparison:
                mismatch = a.data != b.data
                mismatch_fraction = jnp.mean(mismatch.astype(jnp.float32))
                assert (
                    mismatch_fraction < 0.05
                ), f"Mismatch fraction {mismatch_fraction} is too high"
                return
        else:
            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
        assert_allclose(a.data, b.data)

    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
        assert_bitwise_scaled_tensors(
            a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison
        )
        assert_bitwise_scaled_tensors(
            a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison
        )
    else:
        pytest.fail("Unsupported input types")


def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
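    """Assert that dequantizing `a` reproduces the reference array `b`.

    The reference is transposed first when the data layout is "T"; ScaledTensor2x inputs are
    checked through both their row-wise and column-wise halves.
    """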
    if isinstance(a, ScaledTensor1x):
        if a.data_layout == "T":
            flatten_axis = a.data.ndim - a.flatten_axis
            b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
            assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
        else:
            assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_scaled_tensor(a.colwise_tensor, b)
    else:
        pytest.fail("a must be a ScaledTensor object")


def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
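    """Assert that dequantizing each group of `a` reproduces the matching slice of `b`.

    The reference is split along axis 0 according to `group_sizes`, transposed per group when the
    data layout is "T", and compared group by group; ScaledTensor2x inputs are checked through
    both their row-wise and column-wise halves.
    """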
    if isinstance(a, GroupedScaledTensor1x):
        assert a.group_sizes.sum() == b.shape[0]
        b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
        dq_a = a.dequantize()
        for dq_a_i, b_i in zip(dq_a, b):
            if len(dq_a_i) == 0:
                continue
            if a.data_layout == "T":
                data_ndim = len(a.original_shape)
                flatten_axis = a.flatten_axis
                if b_i.shape[0] == 1:
                    b_i = jnp.transpose(
                        b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
                    )
                else:
                    b_i = jnp.transpose(
                        b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
                    )
            dq_a_i = dq_a_i.reshape(b_i.shape)
            assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
        assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
        assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
    else:
        pytest.fail("a must be a GroupedScaledTensor object")


ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [
    ("gelu",),
    ("gelu", "linear"),
    ("silu",),
    ("silu", "linear"),
    ("relu",),
    ("relu", "linear"),
    ("quick_gelu",),
    ("quick_gelu", "linear"),
    ("squared_relu",),
    ("squared_relu", "linear"),
    ("clamped_silu", "clamped_linear"),
]

ACTIVATION_TYPES = {
    "L0": [
        ("gelu",),
        ("gelu", "linear"),
    ],
    "L2": ALL_ACTIVATION_TYPES,
}


class TestActivation:
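    """Tests for the fused activation custom calls against pure-JAX reference implementations."""
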
    def ref_act(self, x, activation_type, act_params):
        return _jax_act_lu(x, activation_type, act_params=act_params).data

    def value_n_grad_ref_func(self, x, activation_type, act_params):
        jitted_reference = jit(
            value_and_grad(
                lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
            )
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer, act_params):
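        """Run the TE activation custom call (optionally quantized) and reduce to its mean."""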
        out = activation(
            inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
        )
        return jnp.mean(out)

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper(
        "activation_type",
        (
            ALL_ACTIVATION_TYPES  # Test every activation type here to ensure all are functional; the other tests only cover a subset
        ),
    )
    def test_act_grad(self, shape, activation_type):
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
        )
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )
        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)),
            static_argnums=(1, 3),
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )

        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )
        prim_out, (prim_grad,) = value_n_grad_primitive_func(
            x, activation_type, quantizer, act_params
        )
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )
        te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )
        output = tex.act_lu(x, activation_type, quantizer, act_params)
        ref_out = self.ref_act(x, activation_type, act_params)
        assert_dequantized_scaled_tensor(output, ref_out)


NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}


@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
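        """Compare value and gradients of layernorm/rmsnorm against the pure-JAX reference."""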
        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
            # This is a no-op for non-quantized data
            ln_out = ln_out.dequantize()
            return ln_out

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
        x = x.astype(inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
        else:
            beta = None

        jitted_reference = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    reference_func(
                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                    )
                ),
                (0, 1, 2),
            )
        )
        jitted_primitive = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
                ),
                (0, 1, 2),
            )
        )

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
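        """Compare quantized norm forward outputs and statistics against the pure-JAX reference,
        relaxing the comparison where unfused fallback paths reduce precision."""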
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
        x = jnp.asarray(x, inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x,
                gamma,
                beta,
                zero_centered_gamma,
                epsilon,
                quantizer=ref_quantizer,
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x,
                gamma,
                zero_centered_gamma,
                epsilon,
                quantizer=ref_quantizer,
            )
            ref_mu = None

        precise_comparison = True

        if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Reduce the test precision: below this cuDNN version we don't use a fused norm for
            # MXFP8 and instead do an unfused norm and quantize with an intermediate cast to the
            # input dtype, which can reduce precision
            precise_comparison = False
        elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
            # Larger tolerances as our JAX implementation _jax_*norm uses the compute dtype float32
            # for zero-centered gamma always
            precise_comparison = False
        elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
            # Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
            # and writes intermediate results into the input dtype, which will slightly reduce precision
            # if the input dtype is not float32
            precise_comparison = False

        assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )


QUANTIZE_OUTPUT_FP8_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
QUANTIZE_OUTPUT_DTYPES = {
    test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn]
    for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
QUANTIZE_QDTYPE_AND_SCALING_MODES = {
    test_level: [
        (q_dtype, scaling_mode)
        for q_dtype in QUANTIZE_OUTPUT_FP8_DTYPES[test_level]
        for scaling_mode in supported_scaling_modes
        if q_dtype in scaling_mode.get_compatible_q_dtypes()
    ]
    for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}

ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    ((32, 64), -1),
    ((2, 64, 32), -1),
    ((64, 2, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    ((64, 32, 32, 256), -1),
    ((8192, 2, 4096), -2),
]

QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((64, 2, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}

QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_layout",
    [
        QuantizeLayout.ROWWISE,
        QuantizeLayout.COLWISE,
        QuantizeLayout.ROWWISE_COLWISE,
    ],
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
        """Skip unsupported dtypes for the given scaling mode.

        For example, NVFP4 only supports the float4_e2m1 dtype, not the float8 dtypes.
        """
        if q_dtype not in scaling_mode.get_compatible_q_dtypes():
            pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
            return

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
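        """Quantize-dequantize roundtrip: the dequantized output must match the original input.

        NVFP4 is instead checked by re-quantizing the dequantized tensor and comparing bitwise
        against the original quantized tensor.
        """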
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)

        key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        if scaling_mode.is_nvfp4_scaling:
            if in_dtype != jnp.bfloat16:
                pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently")
                return
            q_func = _jax_quantize
            # For NVFP4 scaling, the maximum possible error for a single value can be high between the dequantized and original tensors. To ensure quantization and dequantization is operating correctly without requiring a very high tolerance for all values, we instead test that quantizing the dequantized tensor is bitwise identical to the original quantized tensor.
            x = jax.random.uniform(key, input_shape, in_dtype) * 10
            q1 = q_func(x, quantizer=quantizer, flatten_axis=flatten_axis)

            dq_rowwise = None
            dq_colwise = None
            if isinstance(q1, ScaledTensor1x):
                dq = q1.dequantize()
                if q1.is_colwise:
                    dq_colwise = dq
                else:
                    dq_rowwise = dq
            elif isinstance(q1, ScaledTensor2x):
                dq_rowwise = q1.rowwise_tensor.dequantize()
                dq_colwise = q1.colwise_tensor.dequantize()
            else:
                raise ValueError(f"Unsupported output type {type(q1)}")

            # We only compare Q-DQ for the same quantization layout. If we for example QDQ rowwise, then re-quantize colwise, the error will be larger and may not be bitwise identical to the original colwise quantization.
            if dq_rowwise is not None:
                assert (
                    dq_rowwise.shape == x.shape
                ), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}"
                q2_rowwise = q_func(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis)
                q2_rowwise = (
                    q2_rowwise
                    if isinstance(q2_rowwise, ScaledTensor1x)
                    else q2_rowwise.rowwise_tensor
                )
                q1_rowwise = q1 if isinstance(q1, ScaledTensor1x) else q1.rowwise_tensor
                assert_bitwise_scaled_tensors(q1_rowwise, q2_rowwise)

            if dq_colwise is not None:
                # Since this is for NVFP4, we are assuming colwise has T layout and we do a transpose here to get back to original shape
                flatten_axis = flatten_axis + len(input_shape) if flatten_axis < 0 else flatten_axis
                colwise_flatten_axis = len(input_shape) - flatten_axis
                dq_colwise = jnp.transpose(
                    dq_colwise,
                    (*range(colwise_flatten_axis, dq_colwise.ndim), *range(colwise_flatten_axis)),
                )
                assert (
                    dq_colwise.shape == x.shape
                ), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}"
                q2_colwise = q_func(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis)
                q2_colwise = (
                    q2_colwise
                    if isinstance(q2_colwise, ScaledTensor1x)
                    else q2_colwise.colwise_tensor
                )
                q1_colwise = q1 if isinstance(q1, ScaledTensor1x) else q1.colwise_tensor
                assert_bitwise_scaled_tensors(q1_colwise, q2_colwise)

            assert (
                dq_rowwise is not None or dq_colwise is not None
            ), "At least one of rowwise or colwise dq must be not None"
            return

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)

            scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(scaled_tensor, x)

    def _should_use_precise_comparison(
        self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
    ):
        if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
            # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation
            return False

        return True

    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
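        """Check that the TE quantize custom call and the pure-JAX quantize implementation
        produce bitwise-identical ScaledTensors."""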
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)

        te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)

        assert_bitwise_scaled_tensors(
            te_output,
            jax_output,
            precise_comparison=self._should_use_precise_comparison(
                in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
            ),
        )

    def test_quantize_bitwise_jitted(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
        te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))

        jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)

        te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)

        assert_bitwise_scaled_tensors(
            te_output,
            jax_output,
            precise_comparison=self._should_use_precise_comparison(
                in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
            ),
        )


@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper(
    "scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling]
)
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestStochasticRounding:
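    """Tests that NVFP4 quantization with stochastic rounding reduces rounding bias compared to
    round-to-nearest quantization."""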

    def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]:
        """Dequantize a ScaledTensor back to its original jnp.ndarray form.

        Always returns a list of jnp.ndarrays: two tensors for a ScaledTensor2x, one for a
        ScaledTensor1x.
        """
        if isinstance(scaled_tensor, ScaledTensor1x):
            dq = scaled_tensor.dequantize()
            if scaled_tensor.data_layout == "T":
                dq = jnp.transpose(
                    dq,
                    (
                        *range(scaled_tensor.flatten_axis, dq.ndim),
                        *range(scaled_tensor.flatten_axis),
                    ),
                )
            return [dq]
        elif isinstance(scaled_tensor, ScaledTensor2x):
            [rowwise_dq] = self._dequantize(scaled_tensor.rowwise_tensor)
            [colwise_dq] = self._dequantize(scaled_tensor.colwise_tensor)
            return [rowwise_dq, colwise_dq]
        raise ValueError(
            "Unsupported ScaledTensor type, expected ScaledTensor but received"
            f" {type(scaled_tensor)}"
        )

    def _sample_sr_qdq(
        self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
    ) -> list[jnp.ndarray]:
        """Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors."""
        dq_tensors = []

        key = jax.random.PRNGKey(0)

        for i in range(num_samples):
            iter_key = jax.random.fold_in(key, i)
            sr_rng_state = jax.random.randint(
                iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
            )
            quantizer = QuantizerFactory.create(
                q_dtype=q_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                stochastic_rounding_rng_state=sr_rng_state,
            )

            q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
            iter_dq = self._dequantize(q_output)
            dq_tensors.extend(iter_dq)

            avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0)
            assert avg_sr_tensor.shape == inputs.shape, (
                f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
                f" {inputs.shape}"
            )

            sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))

        dq_var = jnp.var(jnp.stack(dq_tensors))
        assert (
            dq_var > 0
        ), "Variance of dequantized tensors is zero, stochastic rounding may not be working"

        return dq_tensors

    def _round_nearest(
        self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
    ) -> jnp.ndarray:
        """Quantizes and dequantizes the input tensor with round nearest quantization."""
        quantizer = QuantizerFactory.create(
            q_dtype=q_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            stochastic_rounding_rng_state=None,
        )
        q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
        rn_dq = self._dequantize(q_output)[0]
        return rn_dq

    def _test_sr(
        self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
    ) -> float:
        """Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples."""
        dq_tensors = self._sample_sr_qdq(
            num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
        )
        avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0)
        assert avg_sr_tensor.shape == inputs.shape, (
            f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
            f" {inputs.shape}"
        )

        round_nearest_tensor = self._round_nearest(
            q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
        )

        sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))
        rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs))

        assert sr_mae < rn_mae, (
            f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than"
            f" round nearest ({rn_mae})"
        )

        return sr_mae

    def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""

        key = jax.random.PRNGKey(0)
        inputs = jax.random.uniform(key, input_shape, in_dtype)

        NUM_SAMPLES = 10

        te_mean_error = self._test_sr(
            NUM_SAMPLES, tex.quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
        )
        jax_mean_error = self._test_sr(
            NUM_SAMPLES, _jax_quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
        )

        assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)


@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper(
    "scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
)
class TestRandomizedHadamardTransform:
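    """Tests for NVFP4 quantization and GEMM with the randomized Hadamard transform (RHT)."""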

    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
    )
    @pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
    def test_rht_quantize_bitwise_jitted(
        self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
    ):
        key = jax.random.PRNGKey(0)
        inputs = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            q_dtype=q_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            use_rht=True,
        )

        jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
        te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))

        jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)

        te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    # We do not test NN and TT layouts here as they do not have both inputs using RHT due to RHT only supporting the colwise layout currently
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [True, False])
    def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm):
        key = jax.random.PRNGKey(0)

        lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        lhs_quantizer = QuantizerFactory.create(
            scaling_mode=lhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
            use_rht=True,
        )
        rhs_quantizer = QuantizerFactory.create(
            scaling_mode=rhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
            use_rht=True,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=lhs_quantizer,
                rhs_quantizer=rhs_quantizer,
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
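    """Tests for grouped quantization, where one tensor holds multiple groups quantized with
    per-group scales."""
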
    def test_grouped_qdq(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
    ):
        n_groups, m, n = input_shape
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # *32 so that the input shapes works for MXFP8
        input_shape = (m * 32, n)

        if with_group_sizes:
            group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
            group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
            group_sizes = jnp.diff(group_sizes)
            assert group_sizes.sum() == m
            assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 row
            group_sizes = group_sizes * 32
        else:
            group_sizes = None
            input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])

        if flatten_axis == -2:
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]

        x = jax.random.uniform(subkeys[1], input_shape, in_dtype)

        grouped_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
            n_groups=n_groups,
        )
        scaled_tensor = tex.grouped_quantize(
            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
        )

        assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)


@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
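    """Tests for quantize kernels fused with dbias and/or activation-gradient computation."""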

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
    @pytest_parametrize_wrapper("out_dtype,scaling_mode", QUANTIZE_QDTYPE_AND_SCALING_MODES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        te_output, te_dbias = jit(
            lambda input: tex.quantize_dbias(
                input, quantizer=te_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        jax_output, jax_dbias = jit(
            lambda input: _jax_quantize_dbias(
                input, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        assert_bitwise_scaled_tensors(te_output, jax_output)

        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)

        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)

        if is_casted_output:
            # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation
            precise_comparison = not (
                in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
            )
            assert_bitwise_scaled_tensors(
                te_output, jax_output, precise_comparison=precise_comparison
            )
        else:
            assert isinstance(te_output, NoScaleTensor)
            assert isinstance(jax_output, NoScaleTensor)
            assert_allclose(te_output.data, jax_output.data)

        if is_dbias:
            precise_comparison = not (
                # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
                (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
                # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
                or (
                    activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
                    and in_dtype == jnp.bfloat16
                    and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
                )
            )
            assert_allclose(
                te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
            )

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_quantize_dact_dbias_tensor_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )


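# E5M2 x E5M2 FP8 GEMM is not supported, so that operand pairing is excluded below.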
valid_fp8_gemm_operand_types = [
    (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
    (jnp.float8_e5m2, jnp.float8_e4m3fn),
    (jnp.float8_e4m3fn, jnp.float8_e5m2),
]

supported_nvfp4_scaling_mode_pairs = [
    (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_1D_SCALING),
    (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING),
]


class TestDense:
    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
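        # A "T" in data_layout means that operand was generated pre-transposed, so swap its last
        # two axes back before taking the plain jnp.dot as the reference.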
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
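        # Generate bf16 operands already stored in the requested "N"/"T" layout, together with
        # the contracting dimensions that describe that layout to the GEMM under test.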
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
        if (
            not with_jax_gemm
            and scaling_mode.is_1d_block_scaling()
            and jnp.float8_e5m2 in (x_qtype, w_qtype)
        ):
            pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")

        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2,
            is_2x2x=False,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=(
                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
                rhs_quantizer=(
                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)

    # TODO(Phuong): add bitwise test
    @pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("scaling_mode_pair", supported_nvfp4_scaling_mode_pairs)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [True, False])
    def test_gemm_nvfp4(self, m, n, k, scaling_mode_pair, data_layout, with_jax_gemm):
        x_uses_rht = scaling_mode_pair[0] == ScalingMode.NVFP4_1D_SCALING and data_layout[0] == "T"
        w_uses_rht = scaling_mode_pair[1] == ScalingMode.NVFP4_1D_SCALING and data_layout[1] == "N"
        if x_uses_rht != w_uses_rht:
            # TODO(jberchtold): Ideally avoid a skip here and rewrite test setup to ensure both or neither use RHT
            pytest.skip("RHT must be used for both or neither operand, skipping")

        lhs_scaling_mode, rhs_scaling_mode = scaling_mode_pair
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        lhs_quantizer = QuantizerFactory.create(
            scaling_mode=lhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
        )
        rhs_quantizer = QuantizerFactory.create(
            scaling_mode=rhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=lhs_quantizer,
                rhs_quantizer=rhs_quantizer,
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest_parametrize_wrapper("m,n,k", [(64, 128, 128)])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)

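        # Run multiple iterations for delayed scaling so the amax history and scales update
        # before the final comparison; other recipes only need a single pass.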
        n_iterations = 3 if recipe.delayed() else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                    value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
                )

        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )

        assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
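    # Pure-JAX reference norm: run the requested norm and dequantize the result so comparisons
    # happen in the unquantized dtype.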
    if norm_type == "rmsnorm":
        ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    ln_out = ln_out.dequantize()
    return ln_out


class TestFusedDense:
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm):
        """
        Test layernorm_dense VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)

        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)

        quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8_and_fp4
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        n_iterations = 3 if recipe.delayed() else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_w_grad,
                    prim_gamma_grad,
                    prim_beta_grad,
                ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, recipe, norm_type, use_bias, with_jax_gemm
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        beta = None  # was tested in TestNorm
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            fp8_recipe=recipe,
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type).data
            linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        n_iterations = 3 if recipe.delayed() else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_gamma_grad,
                    prim_kernel_1_grad,
                    prim_kernel_2_grad,
                    prim_bias_1_grad,
                    prim_bias_2_grad,
                ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        fwd_dtype = quantizer_sets[0].x.q_dtype
        bwd_dtype = quantizer_sets[0].dgrad.q_dtype
        assert_allclose(prim_out, ref_out, dtype=fwd_dtype)
        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=bwd_dtype)
        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=bwd_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=bwd_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=bwd_dtype)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype)
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype)


# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]

GROUPED_DENSE_INPUT_SHAPES = [
    # (n_groups, m, n, k), the actual m will be multiplied by 32
    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
    (8, 64, 32, 128),
    (8, 64, 128, 256),
]


@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
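        # Reference: split the lhs rows by group_sizes and run one dense dot_general per group,
        # adding that group's bias, so the grouped GEMM can be checked group by group.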
        lhs_contract_dim, _ = contracting_dims
        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
        if bias is None:
            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
        else:
            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
        bias = jnp.split(bias, bias.shape[0], axis=0)
        ref_out = []
        dim_num = (contracting_dims, ((), ()))
        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
            out_i = jax.lax.dot_general(
                lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
            ) + jnp.expand_dims(bias_i, axis=0)
            ref_out.append(jnp.squeeze(out_i))
        return ref_out

    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        n_groups, m, n, k = input_shape

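        # Draw (n_groups - 1) sorted cut points in [0, m) and diff against the 0/m endpoints so
        # the group sizes are non-negative and sum exactly to m.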
        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
        group_sizes = jnp.diff(group_sizes)
        # Make one group's lhs empty to test empty GEMM handling
        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
        group_sizes = group_sizes.at[1].set(0)
        assert group_sizes.sum() == m

        # *32 to make sure that input shape works for MXFP8
        group_sizes = group_sizes * 32
        m = m * 32

        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
        bias_shape = (n_groups, n)

        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None

        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return lhs, rhs, group_sizes, contracting_dims, bias

    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
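        # Split the fused output back into per-group chunks and compare each chunk against its
        # per-group reference.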
        assert out.dtype == ref_list[0].dtype
        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
        for i in range(len(ref_list)):
            assert_allclose(out_list[i], ref_list[i], dtype=dtype)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            dtype, input_shape, layout
        )
        num_gemms = input_shape[0]
        _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
            group_sizes,
            num_gemms=num_gemms,
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        # jitting grouped_gemm
        prim_out = jax.jit(
            tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
        )(
            lhs,
            rhs,
            group_sizes,
            contracting_dims,
            use_async_d2h_group_sizes=True,
        )

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=False,
            n_groups=input_shape[0],
        )

        # quantizer_set.{x, kernel} use fwd_dtype, while quantizer_set.dgrad uses bwd_dtype.
        # To test E4M3 * E5M2, manually set quantizer_set.kernel.q_dtype to bwd_dtype.
        quantizer_set.kernel.q_dtype = bwd_dtype
        for quantizer in quantizer_set.kernel.quantizers:
            quantizer.q_dtype = bwd_dtype

        out_dtype = jnp.bfloat16
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            out_dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        allclose_dtype = jnp.float8_e4m3fn
        if jnp.float8_e5m2 in fwd_bwd_dtype:
            allclose_dtype = jnp.float8_e5m2

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)

    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
        # Note: we use jnp.sum instead of jnp.mean to make the gradients larger and prevent
        # them from being clamped to zero in FP8. Dividing by sqrt(x.size) normalizes the
        # output and keeps the gradients from becoming too large for FP8.
        out_sum_list = [jnp.sum(out) for out in out_list]
        return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)

    def _primitive_sum_grouped_dense(
        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
    ):
        out = grouped_dense(
            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
        )
        return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        # jitting the grouped_dense
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x, kernel, bias, group_sizes, contracting_dims
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize(
        "fwd_bwd_dtype",
        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
    )
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        dtype = jnp.bfloat16
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=True,
            n_groups=group_sizes.size,
        )
        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))

        # jitting the grouped_dense
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x,
            kernel,
            bias,
            group_sizes,
            contracting_dims,
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)