# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
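"""Unit tests for TE/JAX custom-call compute primitives: activations, norms,
quantization, dense and grouped dense, and the fused layernorm-dense/MLP layers."""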

import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator

from utils import (
    assert_allclose,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp

from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import (
    _jax_layernorm,
    _jax_rmsnorm,
    is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
    NoScaleTensor,
    ScaledTensor,
    ScaledTensor1x,
    ScaledTensor2x,
    GroupedScaledTensor1x,
    ScalingMode,
    QuantizerFactory,
    QuantizeLayout,
    noop_quantizer_set,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]

is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Find supported scaling modes
supported_scaling_modes = []
if is_fp8_supported:
    supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
    supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
    supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)


def is_shape_supported_by_mxfp8(input_shape):
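    """Return True if MXFP8 1D scaling defines a valid scale shape for ``input_shape``."""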
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except Exception:
        # get_scale_shapes will raise an exception if the shape is not supported
        return False


def assert_bitwise_scaled_tensors(
    a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
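    """Assert that two ScaledTensors match bitwise in their quantized data and scales.

    With ``precise_comparison=False`` only the dequantized values are compared within
    the usual tolerances instead of requiring exact equality.
    """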
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
        if not precise_comparison:
            assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
            return

        assert a.scaling_mode == b.scaling_mode
        assert a.scale_inv.dtype == b.scale_inv.dtype
        if a.scaling_mode.is_tensor_scaling():
            # Assert in dq_dtype, as some unfused code paths have an intermediate cast
            # to the input dtype, which reduces precision compared to keeping everything in fp32
            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Compare MXFP8 scales as uint8
            assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
        else:
            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
        assert_allclose(a.data, b.data)

    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
        assert_bitwise_scaled_tensors(
            a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison
        )
        assert_bitwise_scaled_tensors(
            a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison
        )
    else:
        pytest.fail("Unsupported input types")


def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
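    """Assert that dequantizing ``a`` recovers the reference array ``b`` (handles "T" data layouts)."""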
    if isinstance(a, ScaledTensor1x):
        if a.data_layout == "T":
            flatten_axis = a.data.ndim - a.flatten_axis
            b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
            assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
        else:
            assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_scaled_tensor(a.colwise_tensor, b)
    else:
        pytest.fail("a must be a ScaledTensor object")


def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
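    """Assert that each group of a grouped scaled tensor dequantizes to the matching slice of ``b``."""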
    if isinstance(a, GroupedScaledTensor1x):
        assert a.group_sizes.sum() == b.shape[0]
        b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
        dq_a = a.dequantize()
        for dq_a_i, b_i in zip(dq_a, b):
            if len(dq_a_i) == 0:
                continue
            if a.data_layout == "T":
                data_ndim = len(a.original_shape)
                flatten_axis = a.flatten_axis
                if b_i.shape[0] == 1:
                    b_i = jnp.transpose(
                        b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
                    )
                else:
                    b_i = jnp.transpose(
                        b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
                    )
            dq_a_i = dq_a_i.reshape(b_i.shape)
            assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
        assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
        assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
    else:
        pytest.fail("a must be a GroupedScaledTensor object")


ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [
    ("gelu",),
    ("gelu", "linear"),
    ("silu",),
    ("silu", "linear"),
    ("relu",),
    ("relu", "linear"),
    ("quick_gelu",),
    ("quick_gelu", "linear"),
    ("squared_relu",),
    ("squared_relu", "linear"),
    ("clamped_silu", "clamped_linear"),
]

ACTIVATION_TYPES = {
    "L0": [
        ("gelu",),
        ("gelu", "linear"),
    ],
    "L2": ALL_ACTIVATION_TYPES,
}


class TestActivation:
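    """
    Test transformer_engine.jax.activation.activation and tex.act_lu against the pure-JAX reference
    """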
    def ref_act(self, x, activation_type, act_params):
        return _jax_act_lu(x, activation_type, act_params=act_params).data

    def value_n_grad_ref_func(self, x, activation_type, act_params):
        jitted_reference = jit(
            value_and_grad(
                lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
            )
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer, act_params):
        out = activation(
            inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
        )
        return jnp.mean(out)

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    # Test all activation types here to ensure each one is functional; the other tests
    # only cover a subset to verify the remaining functionality.
    @pytest_parametrize_wrapper("activation_type", ALL_ACTIVATION_TYPES)
    def test_act_grad(self, shape, activation_type):
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
        )
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )
        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)),
            static_argnums=(1, 3),
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )

        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )
        prim_out, (prim_grad,) = value_n_grad_primitive_func(
            x, activation_type, quantizer, act_params
        )
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )
        te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )
        output = tex.act_lu(x, activation_type, quantizer, act_params)
        ref_out = self.ref_act(x, activation_type, act_params)
        assert_dequantized_scaled_tensor(output, ref_out)


NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}


@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
            # This is a no-op for non-quantized data
            ln_out = ln_out.dequantize()
            return ln_out

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
        x = x.astype(inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
        else:
            beta = None

        jitted_reference = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    reference_func(
                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                    )
                ),
                (0, 1, 2),
            )
        )
        jitted_primitive = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
                ),
                (0, 1, 2),
            )
        )

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm with zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm with zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
        x = jnp.asarray(x, inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x,
                gamma,
                beta,
                zero_centered_gamma,
                epsilon,
                quantizer=ref_quantizer,
516
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x,
                gamma,
                zero_centered_gamma,
                epsilon,
                quantizer=ref_quantizer,
            )
            ref_mu = None

        precise_comparison = True

        if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Reduce test precision: below this cuDNN version we don't use a fused norm for MXFP8
            # and instead do an unfused norm and quantize, with an intermediate cast to the input
            # dtype that can reduce precision
            precise_comparison = False
        elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
            # Larger tolerances as our JAX implementation _jax_*norm uses the compute dtype float32
            # for zero-centered gamma always
            precise_comparison = False
        elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
            # Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
            # and writes intermediate results into the input dtype, which will slightly reduce precision
            # if the input dtype is not float32
            precise_comparison = False

        assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm with zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )


QUANTIZE_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}

ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    ((32, 64), -1),
    ((2, 64, 32), -1),
    ((64, 2, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    ((64, 32, 32, 256), -1),
    ((64, 32, 32, 256), -2),
    ((64, 32, 32, 256), -3),
620
]

QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((64, 2, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}

QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)

            scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(scaled_tensor, x)

    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)

        te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
        assert_bitwise_scaled_tensors(te_output, jax_output)


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
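    """
    Test tex.grouped_quantize by dequantizing its output and comparing against the input
    """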
    def test_grouped_qdq(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
    ):
        n_groups, m, n = input_shape
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # *32 so that the input shape works for MXFP8
        input_shape = (m * 32, n)

        if with_group_sizes:
            group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
            group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
            group_sizes = jnp.diff(group_sizes)
            assert group_sizes.sum() == m
            assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 rows
            group_sizes = group_sizes * 32
        else:
            group_sizes = None
            input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])

        if flatten_axis == -2:
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]

        x = jax.random.uniform(subkeys[1], input_shape, in_dtype)

        grouped_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
            n_groups=n_groups,
        )

        scaled_tensor = tex.grouped_quantize(
            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
        )

        assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)


@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
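    """
    Test the fused quantize kernels (quantize_dbias and quantize_dact_dbias) against the pure-JAX references
    """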

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        te_output, te_dbias = jit(
            lambda input: tex.quantize_dbias(
                input, quantizer=te_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        jax_output, jax_dbias = jit(
            lambda input: _jax_quantize_dbias(
                input, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        assert_bitwise_scaled_tensors(te_output, jax_output)

        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)

        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)

        if is_casted_output:
            # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation
            precise_comparison = not (
                in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
            )
            assert_bitwise_scaled_tensors(
                te_output, jax_output, precise_comparison=precise_comparison
            )
        else:
            assert isinstance(te_output, NoScaleTensor)
            assert isinstance(jax_output, NoScaleTensor)
            assert_allclose(te_output.data, jax_output.data)

        if is_dbias:
            precise_comparison = not (
                # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
                (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
                # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
                or (
                    activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
                    and in_dtype == jnp.bfloat16
                    and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
                )
            )
            assert_allclose(
                te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
            )

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_quantize_dact_dbias_tensor_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )


valid_fp8_gemm_operand_types = [
    (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
    (jnp.float8_e5m2, jnp.float8_e4m3fn),
    (jnp.float8_e4m3fn, jnp.float8_e5m2),
]


class TestDense:
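    """
    Test tex.gemm and the dense() VJP rule in BF16 and FP8
    """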
    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
        if (
            not with_jax_gemm
            and scaling_mode.is_1d_block_scaling()
            and jnp.float8_e5m2 in (x_qtype, w_qtype)
        ):
            pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")

        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2,
            is_2x2x=False,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=(
                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
                rhs_quantizer=(
                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                    value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
                )

        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )

        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
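    """Random bfloat16 input of the requested shape, drawn uniformly from [5, 8)."""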
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
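    """Pure-JAX reference norm; the output is dequantized so it can feed a reference matmul."""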
    if norm_type == "rmsnorm":
        ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    ln_out = ln_out.dequantize()
    return ln_out


class TestFusedDense:
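    """
    Test the fused layernorm_dense and layernorm_mlp VJP rules
    """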
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
        """
        Test layernorm_dense VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)

        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_w_grad,
                    prim_gamma_grad,
                    prim_beta_grad,
                ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        beta = None  # was tested in TestNorm
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type).data
            linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_gamma_grad,
                    prim_kernel_1_grad,
                    prim_kernel_2_grad,
                    prim_bias_1_grad,
                    prim_bias_2_grad,
                ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)

        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)

        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
        if use_bias:
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)

        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)


# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]

GROUPED_DENSE_INPUT_SHAPES = [
    # (n_groups, m, n, k), the actual m will be multiplied by 32
    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
    (8, 64, 32, 128),
    (8, 64, 128, 256),
]


@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
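        # Reference implementation: split lhs along its non-contracting axis by
        # group_sizes and run one dot_general per group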
        lhs_contract_dim, _ = contracting_dims
        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
        if bias is None:
            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
        else:
            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
        bias = jnp.split(bias, bias.shape[0], axis=0)
        ref_out = []
        dim_num = (contracting_dims, ((), ()))
        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
            out_i = jax.lax.dot_general(
                lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
            ) + jnp.expand_dims(bias_i, axis=0)
            ref_out.append(jnp.squeeze(out_i))
        return ref_out

    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
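        # Builds random lhs/rhs (and optional bias) tensors plus group_sizes partitioning the rows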
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        n_groups, m, n, k = input_shape

        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
        group_sizes = jnp.diff(group_sizes)
        # Make one lhs group empty to test handling of empty GEMMs
        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
        group_sizes = group_sizes.at[1].set(0)
        assert group_sizes.sum() == m

        # *32 to make sure that input shape works for MXFP8
        group_sizes = group_sizes * 32
        m = m * 32

        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
        bias_shape = (n_groups, n)

        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None

        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return lhs, rhs, group_sizes, contracting_dims, bias

    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
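        # Split the flat output row-wise by group_sizes and compare each group against its reference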
        assert out.dtype == ref_list[0].dtype
        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
        for i in range(len(ref_list)):
            assert_allclose(out_list[i], ref_list[i], dtype=dtype)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            dtype, input_shape, layout
        )
        num_gemms = input_shape[0]
        _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
            group_sizes,
            num_gemms=num_gemms,
        )
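        # Result is discarded; this just runs grouped_gemm_copy_group_sizes under jit
        # before the grouped GEMM below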
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        # jitting grouped_gemm
        prim_out = jax.jit(
            tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
        )(
            lhs,
            rhs,
            group_sizes,
            contracting_dims,
            use_async_d2h_group_sizes=True,
        )

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=False,
            n_groups=input_shape[0],
        )

        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
        # To test E4M3 * E5M2, manually set quantizer_set.kernel.q_dtype to bwd_dtype
        quantizer_set.kernel.q_dtype = bwd_dtype
        for quantizer in quantizer_set.kernel.quantizers:
            quantizer.q_dtype = bwd_dtype

        out_dtype = jnp.bfloat16
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            out_dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        # Compare with the looser E5M2 tolerance when either operand was quantized to E5M2
        allclose_dtype = jnp.float8_e4m3fn
        if jnp.float8_e5m2 in fwd_bwd_dtype:
            allclose_dtype = jnp.float8_e5m2

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)

    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
        # Note: we use jnp.sum instead of jnp.mean to make the gradients larger
        # and prevent them from being clamped to zero in FP8. Dividing by sqrt(x.size)
        # normalizes the output and keeps the gradients from growing too large for FP8.
        out_sum_list = [jnp.sum(out) for out in out_list]
        return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)

    def _primitive_sum_grouped_dense(
        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
    ):
        out = grouped_dense(
            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
        )
        return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        # jitting the grouped_dense
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x, kernel, bias, group_sizes, contracting_dims
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize(
        "fwd_bwd_dtype",
        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
    )
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        dtype = jnp.bfloat16
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=True,
            n_groups=group_sizes.size,
        )
        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))

        # jitting the grouped_dense
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x,
            kernel,
            bias,
            group_sizes,
            contracting_dims,
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)