# [web-view residue] test_custom_call_compute.py 53.8 KB -- Newer / Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#
# See LICENSE for license information.

import jax
import jax.numpy as jnp
7
import numpy as np
8
9
import pytest
from jax import jit, value_and_grad
10
11
12
13
14
15
16
17
18
19
20
21
from functools import reduce
import operator

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    pytest_parametrize_wrapper,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp

from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
22
23
24
25
26
from transformer_engine.jax.cpp_extensions.normalization import (
    _jax_layernorm,
    _jax_rmsnorm,
    is_norm_zero_centered_gamma_in_weight_dtype,
)
27
28
29
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
30
)
31
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
32
from transformer_engine.jax import cpp_extensions as tex
33
34
35
36
37
from transformer_engine.jax.quantize import (
    DelayedScaleQuantizer,
    ScaledTensor,
    ScalingMode,
    QuantizerFactory,
38
    QuantizeLayout,
39
40
41
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
42
from transformer_engine.jax.dense import dense
43
44
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
45

# [web-view residue] Tim Moon committed
46
47
48
49
50
51
52
# Problem sizes for GEMM tests (presumably (m, n, k) triples -- confirm at the use site).
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
# (n, hidden) pairs used to parametrize the norm tests.
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]

# NOTE(review): both queries bind ``reason``, so from here on ``reason`` holds the
# message returned by the MXFP8 query -- including in skipif marks that are guarded
# by ``is_fp8_supported``.
is_fp8_supported, reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Scaling modes that the capability queries above report as available.
supported_scaling_modes = []
if is_fp8_supported:
    supported_scaling_modes += [
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    ]
if is_mxfp8_supported:
    supported_scaling_modes += [ScalingMode.MXFP8_1D_SCALING]
66
67
68
69
70
71


def is_shape_supported_by_mxfp8(input_shape):
    """Return whether MXFP8 1D scaling can derive scale shapes for ``input_shape``.

    Args:
        input_shape: A shape, or a ``pytest.param`` whose first value is the shape.

    Returns:
        True if ``ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x`` accepts the
        shape, False if it (or unwrapping the ``pytest.param``) raises.
    """
    try:
        # Parametrization lists may wrap shapes in pytest.param; unwrap to the raw shape.
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except Exception:
        # get_scale_shape_2x raises if the shape is not supported. Catch Exception
        # rather than using a bare except so KeyboardInterrupt/SystemExit still propagate.
        return False


def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
    """Assert that two scaled tensors of the same kind carry matching data and scales.

    A pair of ``ScaledTensor1x`` is compared directly; a pair of ``ScaledTensor2x``
    is compared by recursing into the rowwise and colwise tensors. Any other
    combination fails the test.
    """
    pair_is_1x = isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x)
    if not pair_is_1x:
        if isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
            assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
            assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
        else:
            pytest.fail("Unsupported input types")
        return

    mode = a.scaling_mode
    assert mode == b.scaling_mode
    assert a.scale_inv.dtype == b.scale_inv.dtype

    if mode.is_tensor_scaling():
        # Compare in dq_dtype: some unfused codepaths cast through the input dtype,
        # which loses precision relative to keeping everything in fp32.
        assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
    elif mode == ScalingMode.MXFP8_1D_SCALING:
        # MXFP8 scales are compared after viewing them as uint8.
        lhs_scales = a.scale_inv.astype(jnp.uint8)
        rhs_scales = b.scale_inv.astype(jnp.uint8)
        assert_allclose(lhs_scales, rhs_scales)
    else:
        raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")

    assert_allclose(a.data, b.data)


def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
    """Assert that dequantizing ``a`` reproduces ``b`` within the tolerance of ``a``'s data dtype.

    For a ``ScaledTensor1x`` stored with ``data_layout == "T"``, ``b`` is first
    permuted so that the axes from the flatten split onward come first. A
    ``ScaledTensor2x`` is checked through its rowwise and colwise tensors.
    """
    if isinstance(a, ScaledTensor1x):
        expected = b
        if a.data_layout == "T":
            split = a.data.ndim - a.flatten_axis
            permutation = tuple(range(split, b.ndim)) + tuple(range(split))
            expected = jnp.transpose(b, permutation)
        assert_allclose(a.dequantize(), expected, dtype=a.data.dtype)
        return

    if isinstance(a, ScaledTensor2x):
        assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b)
        assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b)
        return

    pytest.fail("a must be a ScaledTensor object")


# Input shapes shared by the activation tests.
ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]

# Every base activation, once on its own and once paired with "linear".
ALL_ACTIVATION_TYPES = [
    variant
    for base in ("gelu", "silu", "relu", "quick_gelu", "squared_relu")
    for variant in ((base,), (base, "linear"))
]

# Activation sets keyed by "L0"/"L2" (presumably test levels resolved by
# pytest_parametrize_wrapper -- confirm); "L2" is the full list.
ACTIVATION_TYPES = {
    "L0": [("gelu",), ("gelu", "linear")],
    "L2": ALL_ACTIVATION_TYPES,
}
137
138


139
140
141
142
143
144
145
146
147
class TestActivation:
    """Compare the TE activation paths against the pure-JAX reference ``_jax_act_lu``."""

    def ref_act(self, x, activation_type):
        """Reference activation; delegates to ``_jax_act_lu``."""
        return _jax_act_lu(x, activation_type)

    def value_n_grad_ref_func(self, x, activation_type):
        """Return the jitted value and gradient (w.r.t. ``x``) of ``mean(ref_act(x))``."""
        jitted_reference = jit(
            value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer):
        """Scalar loss: mean of ``activation(inputs, ...)``, differentiated by the tests."""
        out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
        return jnp.mean(out)

    # All activation types are exercised here to make sure each one is functional;
    # the remaining tests only use a subset to verify other functionality.
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ALL_ACTIVATION_TYPES)
    def test_act_grad(self, shape, activation_type):
        """Unquantized activation: value and gradient must match the reference."""
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        # Insert an axis at -2 and repeat it once per entry in activation_type.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        """Tensor-scaling FP8 activation: value and gradient must match the reference.

        ``random_inputs`` is a fixture defined outside this class (presumably it
        consumes the ``shape`` parameter -- confirm).
        """
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    # This test only uses delayed/current tensor scaling, so it is gated on plain FP8
    # support (like the other tensor-scaling tests) rather than on MXFP8 support.
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        """Tensor-scaling FP8 forward: ``tex.act_lu`` must match ``_jax_act_lu`` bitwise."""
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        # One quantizer per implementation so neither sees the other's state.
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        """MXFP8 forward: the dequantized ``tex.act_lu`` output must match the reference."""
        x = random_inputs
        # No expand_dims here; the parametrized shape (2, 64, 1, 256) already has a
        # size-1 axis at -2 (assuming random_inputs follows ``shape`` -- confirm).
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )

        output = tex.act_lu(x, activation_type, quantizer)
        ref_out = self.ref_act(x, activation_type)

        assert_dequantized_scaled_tensor(output, ref_out)
259
260


261
262
263
264
# Quantized output dtypes for norm tests, keyed by "L0"/"L2" (presumably test levels
# resolved by pytest_parametrize_wrapper -- confirm). Not referenced in the visible
# part of this file.
NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
265

266

267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        """Check value and gradients of ``layernorm`` against the pure-JAX norm.

        The reference always runs without a quantizer; ``quantizer`` (may be None)
        is only passed to ``layernorm``. Tolerances follow ``inp_dtype`` when
        ``quantizer`` is None and ``quantizer.q_dtype`` otherwise.
        """

        def loss_of(out):
            # Reduce in float32 for precision, then cast back to the output's dtype.
            return jnp.mean(jnp.square(out.astype(jnp.float32))).astype(out.dtype)

        def reference_loss(x, gamma, beta):
            if norm_type == "rmsnorm":
                normed, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, None)
            else:
                normed, _, _ = _jax_layernorm(
                    x, gamma, beta, zero_centered_gamma, epsilon, None
                )
            return loss_of(normed)

        def primitive_loss(x, gamma, beta):
            return loss_of(
                layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
            )

        x_key, gamma_key, beta_key = jax.random.split(jax.random.PRNGKey(0), 3)

        x = jax.random.uniform(x_key, (n, hidden), jnp.float32, -1, 1).astype(inp_dtype)
        gamma_bounds = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jnp.asarray(
            jax.random.uniform(gamma_key, (hidden,), jnp.float32, *gamma_bounds), inp_dtype
        )
        beta = None
        if norm_type == "layernorm":
            beta = jnp.asarray(
                jax.random.uniform(beta_key, (hidden,), jnp.float32, -1, 1), inp_dtype
            )

        grad_argnums = (0, 1, 2)
        run_reference = jit(value_and_grad(reference_loss, grad_argnums))
        run_primitive = jit(value_and_grad(primitive_loss, grad_argnums))

        ref_value, (ref_dx, ref_dgamma, ref_dbeta) = run_reference(x, gamma, beta)
        prim_value, (prim_dx, prim_dgamma, prim_dbeta) = run_primitive(x, gamma, beta)

        tol_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(prim_value, ref_value, dtype=tol_dtype)
        assert_allclose(prim_dx, ref_dx, dtype=tol_dtype)
        assert_allclose(prim_dgamma, ref_dgamma, dtype=tol_dtype)
        if beta is not None:
            assert_allclose(prim_dbeta, ref_dbeta, dtype=tol_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if zero_centered_gamma is True and norm_type == "rmsnorm":
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if zero_centered_gamma is True and norm_type == "rmsnorm":
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        fp8_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, fp8_quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
        """Compare quantized forward norm outputs from ``tex`` against the pure-JAX norm.

        The quantized output is compared bitwise unless one of the conditions
        below applies, in which case dequantized values are compared with
        ``out_dtype`` tolerances. ``rsigma`` (and ``mu`` for layernorm) are
        always compared with ``inp_dtype`` tolerances.
        """
        x_key, gamma_key, beta_key = jax.random.split(jax.random.PRNGKey(0), 3)

        x = jnp.asarray(jax.random.uniform(x_key, (n, hidden), inp_dtype, -1, 1), inp_dtype)
        gamma_bounds = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jnp.asarray(
            jax.random.uniform(gamma_key, (hidden,), jnp.float32, *gamma_bounds), inp_dtype
        )

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )

        is_layernorm = norm_type == "layernorm"
        if is_layernorm:
            beta = jnp.asarray(
                jax.random.uniform(beta_key, (hidden,), jnp.float32, -1, 1), inp_dtype
            )
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
            ref_mu = None

        # A bitwise match is not expected when any of the following holds:
        #  1. cuDNN < 9.10 with MXFP8: the fused norm is not used; an unfused norm +
        #     quantize runs with an intermediate cast to the input dtype.
        #  2. Zero-centered gamma is handled in the weight dtype: the JAX reference
        #     always uses float32 as compute dtype for zero-centered gamma.
        #  3. Current tensor scaling with a non-float32 input: layernorm and
        #     quantization are unfused and intermediates are written in the input dtype.
        loose_comparison = (
            (get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING)
            or is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode)
            or (scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32)
        )

        if not loose_comparison:
            assert_bitwise_scaled_tensors(output, ref_out)
        elif isinstance(ref_out, ScaledTensor1x):
            assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
        elif isinstance(ref_out, ScaledTensor2x):
            assert_allclose(
                output.rowwise_tensor.dequantize(),
                ref_out.rowwise_tensor.dequantize(),
                dtype=out_dtype,
            )
            assert_allclose(
                output.colwise_tensor.dequantize(),
                ref_out.colwise_tensor.dequantize(),
                dtype=out_dtype,
            )
        else:
            pytest.fail("Unsupported output type")

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if is_layernorm:
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """Forward norm check with delayed/current tensor scaling."""
        if zero_centered_gamma is True and norm_type == "rmsnorm":
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        """Forward norm check with MXFP8 1D scaling, always in ROWWISE_COLWISE layout."""
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )
525
526


527
528
529
530
# Quantized output dtypes, keyed by "L0"/"L2" (presumably test levels resolved by
# pytest_parametrize_wrapper -- confirm).
QUANTIZE_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}

# (input_shape, flatten_axis) pairs; every entry must be a 2-tuple because the
# list is unpacked into the "input_shape,flatten_axis" parametrization.
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    ((32, 64), -1),
    ((2, 64, 32), -1),
    ((2, 64, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    ((64, 32, 32, 256), -1),
    ((64, 32, 32, 256), -2),
    ((64, 32, 32, 256), -3),
]

# "L0" is a reduced subset of the pairs above; "L2" is the full list.
QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((2, 64, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}

# Input dtypes fed to the quantization tests, keyed the same way.
QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        """Quantize then dequantize and check that the input is recovered."""
        rng_key = jax.random.PRNGKey(0)

        # A single quantizer is reused across iterations because some quantization
        # approaches keep state from previous iterations (e.g. delayed scaling).
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        # Insert an extra dimension to check that padding is handled correctly
        # when flattening 3D to 2D.
        if flatten_axis == -2:
            input_shape = (*input_shape[:-1], 2, *input_shape[-1:])

        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
            n_iterations = 3
        else:
            n_iterations = 1

        for _ in range(n_iterations):
            # NOTE: the same key is reused, so every iteration draws identical data.
            x = jax.random.uniform(rng_key, input_shape, in_dtype)
            quantized = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(quantized, x)

    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """``tex.quantize`` must match the pure-JAX ``_jax_quantize`` bitwise."""
        rng_key = jax.random.PRNGKey(0)

        if flatten_axis == -2:
            input_shape = (*input_shape[:-1], 2, *input_shape[-1:])
        data = jax.random.uniform(rng_key, input_shape, in_dtype)

        # One quantizer per implementation.
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(data, quantizer=jax_quantizer, flatten_axis=flatten_axis)
        te_output = tex.quantize(data, quantizer=te_quantizer, flatten_axis=flatten_axis)

        assert_bitwise_scaled_tensors(te_output, jax_output)
608
609
610
611
612
613
614


@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
    """Compare the fused TE quantize(+dbias/+dact) calls against their pure-JAX counterparts."""

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """``tex.quantize_dbias`` must match ``_jax_quantize_dbias`` (output bitwise, dbias close)."""
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        # Normalize the (possibly negative) axis before validating it. The previous
        # check also skipped on ``flatten_axis <= 0``, which is true for every
        # negative axis in the parametrization, so the test never ran.
        ndim = len(input_shape)
        split_axis = flatten_axis + ndim if flatten_axis < 0 else flatten_axis
        if not 0 < split_axis < ndim:
            pytest.skip(
                f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
                " must be at least one axis on either side of the flatten_axis split."
            )

        key = jax.random.PRNGKey(0)
        data = jax.random.uniform(key, input_shape, in_dtype)

        # One quantizer per implementation.
        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        te_output, te_dbias = jit(
            lambda arr: tex.quantize_dbias(arr, quantizer=te_quantizer, flatten_axis=flatten_axis)
        )(data)

        jax_output, jax_dbias = jit(
            lambda arr: _jax_quantize_dbias(
                arr, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(data)

        assert_bitwise_scaled_tensors(te_output, jax_output)

        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
        """Compare ``tex.quantize_dact_dbias`` with ``_jax_quantize_dact_dbias``.

        The output is compared bitwise as a scaled tensor when the factory returned
        a quantizer, and with ``assert_allclose`` when it returned None. ``dbias``
        is only compared when ``is_dbias`` is true.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        # Insert an axis at -2 and repeat it once per entry in activation_type.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)

        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)

        if is_casted_output:
            assert_bitwise_scaled_tensors(te_output, jax_output)
        else:
            assert_allclose(te_output, jax_output)

        if is_dbias:
            assert_allclose(te_dbias, jax_dbias)

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        """dact(+dbias) with ``ScalingMode.NO_SCALING``; the output dtype equals the input dtype."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_quantize_dact_dbias_tensor_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
    ):
        """dact(+dbias) with delayed/current tensor scaling."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        """dact(+dbias) with MXFP8 1D scaling; only shapes made of full 128x128 tiles are run."""
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )
774
775


776
class TestDense:
    """Checks ``tex.gemm`` and ``dense`` against ``jnp.dot`` based references.

    Layouts are two-character strings, one character per operand (LHS first).
    An operand marked "T" is generated with its two axes swapped relative to
    the "N" form and is swapped back by the reference implementation.
    """

    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        """Return ``jnp.dot(a, b)`` after swapping the last two axes of "T" operands."""
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        """Build bfloat16 GEMM operands plus the matching contracting dimensions.

        Returns:
            ``(x, w, contracting_dims)``. ``x`` has shape ``(m, k)`` for an "N"
            LHS and ``(k, m)`` otherwise; ``w`` has shape ``(k, n)`` for an "N"
            RHS and ``(n, k)`` otherwise. ``contracting_dims`` is the pair
            ``(lhs_dims, rhs_dims)`` selecting the ``k`` axis of each operand.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        """Compare ``tex.gemm`` on bfloat16 operands with the ``jnp.dot`` reference."""
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        primitive_out = tex.gemm(x, w, contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
        """Compare ``tex.gemm`` with a quantizer set against the unquantized reference."""
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
        )
        primitive_out = tex.gemm(
            x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
        )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
        """Compare value and input/kernel gradients of ``dense`` with the reference."""
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
        """Compare value and x/kernel/bias gradients of quantized ``dense`` with the reference."""
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        # Pass the shape as an explicit tuple rather than a bare int.
        bias = jax.random.uniform(key, (n,), dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
        )

        # Delayed scaling is evaluated three times with the same quantizer set;
        # only the results of the last pass are compared below.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
            )

        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
898
899


900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    """Provide a bfloat16 array of the requested ``shape`` drawn uniformly between 5 and 8.

    The PRNG key is fixed, so the fixture yields the same values on every run.
    """
    root_key = jax.random.PRNGKey(0)
    # Keep the four-way split: only the first subkey is consumed, but the
    # split width is left untouched so the generated values stay the same.
    first_subkey = jax.random.split(root_key, 4)[0]
    return jax.random.uniform(first_subkey, shape, jnp.bfloat16, 5, 8)


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
    """Reference normalization built on the pure-JAX norm helpers.

    Dispatches to ``_jax_rmsnorm`` when ``norm_type`` is ``"rmsnorm"`` and to
    ``_jax_layernorm`` for any other value, keeping only the first element of
    the returned tuple. A ``ScaledTensor`` result is dequantized before it is
    returned.
    """
    if norm_type == "rmsnorm":
        normed, _unused = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        normed, _unused_a, _unused_b = _jax_layernorm(
            x, gamma, beta, zero_centered_gamma, eps, quantizer
        )
    return normed.dequantize() if isinstance(normed, ScaledTensor) else normed


class TestFusedDense:
    """Gradient (VJP) tests for the fused ``layernorm_dense`` and ``layernorm_mlp`` ops."""

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
        """
        Test layernorm_dense VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        ):
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)

        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        # Only the layernorm case gets a beta; it is passed as None for rmsnorm.
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        # Delayed scaling is evaluated three times with the same quantizer set;
        # only the results of the last pass are compared below.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_w_grad,
                prim_gamma_grad,
                prim_beta_grad,
            ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        ):
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        # beta is closed over by both functions below and is not differentiated.
        # NOTE: it is drawn from subkeys[3], the same subkey used for bias_1.
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            # TODO: replace gemm with jnp.dot
            linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
            if use_bias:
                # Left-pad the bias shape with ones so it broadcasts over the leading axes.
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)
            linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        # Delayed scaling is evaluated three times with the same quantizer sets;
        # only the results of the last pass are compared below.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_gamma_grad,
                prim_kernel_1_grad,
                prim_kernel_2_grad,
                prim_bias_1_grad,
                prim_bias_2_grad,
            ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)

        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)

        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)

        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)


# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
    """Quantize both GEMM operands, each with its own quantizer.

    ``contracting_dims`` must hold exactly one contracting axis per operand.
    An operand is quantized row-wise when that axis is its last axis and
    column-wise otherwise.

    Returns:
        The pair ``(quantized_lhs, quantized_rhs)``.
    """
    (lhs_axis,), (rhs_axis,) = contracting_dims
    operands = (lhs, rhs)
    # Resolve both orientation flags before any quantizer is invoked.
    rowwise_flags = [
        axis == operand.ndim - 1 for operand, axis in zip(operands, (lhs_axis, rhs_axis))
    ]
    return tuple(
        quantizer.quantize(operand, is_rowwise=rowwise, is_colwise=not rowwise)
        for operand, quantizer, rowwise in zip(
            operands, (lhs_quantizer, rhs_quantizer), rowwise_flags
        )
    )

1149

1150
1151
1152
1153
1154
1155
1156
# Every [forward dtype, backward dtype] pairing of the two FP8 formats except
# E5M2 * E5M2, which is not supported.
fwd_bwd_dtypes = [
    [fwd_dtype, bwd_dtype]
    for fwd_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
    for bwd_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
    if not (fwd_dtype is jnp.float8_e5m2 and bwd_dtype is jnp.float8_e5m2)
]

1157
# NOTE(review): the TestGroupedDense suite below is wrapped in a module-level
# string literal, so none of it is executed or collected by pytest. It calls
# `grouped_dense`, which does not appear among the imports at the top of this
# file -- confirm the import before re-enabling the suite.
"""
@pytest_parametrize_wrapper(
    "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
)
class TestGroupedDense:
    def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
        ref_out_list = []
        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
            dim_nums = (contracting_dims, ((), ()))
            ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
        return ref_out_list

    def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, len(shape_list) * 2)

        lhs_list, rhs_list, contracting_dims_list = [], [], []
        for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
            lhs = jax.random.uniform(
                subkeys[2 * i],
                (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
                dtype=dtype,
            )
            rhs = jax.random.uniform(
                subkeys[2 * i + 1],
                (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
                dtype=dtype,
            )
            lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
            rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

            lhs_list.append(lhs)
            rhs_list.append(rhs)
            contracting_dims_list.append(contracting_dims)

        return lhs_list, rhs_list, contracting_dims_list

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
    def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
            dtype, shape_list, layout_list
        )
        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
        primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
        for i in range(len(shape_list)):
            assert_allclose(primitive_out[i], ref_out[i], dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
        )

        out_dtype = jnp.bfloat16
        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
            out_dtype, shape_list, layout_list
        )
        q_lhs_list = []
        q_rhs_list = []
        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
            # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to
            # test the case where lhs and rhs have different q_dtypes
            q_lhs, q_rhs = _quantize_gemm_pair(
                lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
            )
            q_lhs_list.append(q_lhs)
            q_rhs_list.append(q_rhs)

        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
        primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)

        allclose_dtype = jnp.float8_e4m3fn
        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
            allclose_dtype = jnp.float8_e5m2
        for i in range(len(shape_list)):
            assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, shape_list):
        group_size = len(shape_list)
        layout_list = ["NN" for _ in range(group_size)]

        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
            dtype, shape_list, layout_list
        )
        bias_list = []
        key = jax.random.PRNGKey(1)
        for shape in shape_list:
            n = shape[1]
            bias = jax.random.uniform(key, n, dtype=dtype)
            bias_list.append(bias)

        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
            out_list = []
            for i in range(len(x_list)):
                out_list.append(
                    dense(
                        x_list[i],
                        kernel_list[i],
                        bias_list[i],
                        contracting_dims=contracting_dims_list[i],
                    )
                )
            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
            # and prevent them from being clamp to zero
            out_sum_list = [jnp.sum(out) for out in out_list]
            return jnp.sum(jnp.asarray(out_sum_list))

        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
            out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
            out_sum_list = [jnp.sum(out) for out in out_list]
            return jnp.sum(jnp.asarray(out_sum_list))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))

        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
            x_list, kernel_list, bias_list, contracting_dims_list
        )
        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
            value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
        )

        assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
        for i in range(group_size):
            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
        group_size = len(shape_list)
        layout_list = ["NN" for _ in range(group_size)]
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        if fwd_dtype == jnp.float8_e5m2:
            pytest.skip("We never use E5M2 for fwd_dtype in training")

        # Question: should we use different quantizers for different groups?
        ref_quantizer_set_list = []
        quantizer_set_list = []
        for _ in range(group_size):
            ref_quantizer_set = QuantizerFactory.create_set(
                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
            )
            ref_quantizer_set_list.append(ref_quantizer_set)
            quantizer_set = QuantizerFactory.create_set(
                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
            )
            quantizer_set_list.append(quantizer_set)

        out_dtype = jnp.bfloat16
        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
            out_dtype, shape_list, layout_list
        )
        bias_list = []
        key = jax.random.PRNGKey(1)
        for shape in shape_list:
            n = shape[1]
            bias = jax.random.uniform(key, n, dtype=out_dtype)
            bias_list.append(bias)

        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
            out_list = []
            for i in range(len(x_list)):
                out_list.append(
                    dense(
                        x_list[i],
                        kernel_list[i],
                        bias_list[i],
                        contracting_dims=contracting_dims_list[i],
                        quantizer_set=quantizer_set_list[i],
                    )
                )
            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
            # and prevent them from being clamp to zero
            out_sum_list = [jnp.sum(out) for out in out_list]
            return jnp.sum(jnp.asarray(out_sum_list))

        def primitive_func(
            x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
        ):
            out_list = grouped_dense(
                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
            )
            out_sum_list = [jnp.sum(out) for out in out_list]
            return jnp.sum(jnp.asarray(out_sum_list))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))

        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
            x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
        )
        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
            value_n_grad_primitive_func(
                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
            )
        )

        allclose_dtype = jnp.float8_e4m3fn
        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
            allclose_dtype = jnp.float8_e5m2
        assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
        for i in range(group_size):
            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""