test_custom_call_compute.py 50.3 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.

import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from functools import reduce
import operator

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    pytest_parametrize_wrapper,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp

from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm, _jax_rmsnorm
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
25
)
26
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
27
from transformer_engine.jax import cpp_extensions as tex
28
29
30
31
32
from transformer_engine.jax.quantize import (
    DelayedScaleQuantizer,
    ScaledTensor,
    ScalingMode,
    QuantizerFactory,
33
    QuantizeLayout,
34
35
36
37
38
39
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
40

Tim Moon's avatar
Tim Moon committed
41
42
43
44
45
46
47
# GEMM problem sizes as (m, n, k).
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]

# Device capability probes.
# NOTE(review): both lines rebind the module-level ``reason``; the second one
# wins, so every ``skipif(..., reason=reason)`` in this file reports the MXFP8
# availability message -- including skips guarded by ``is_fp8_supported``.
is_fp8_supported, reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Scaling modes usable on this device, in a fixed order: delayed tensor
# scaling first, then MXFP8 1D block scaling.
supported_scaling_modes = [
    mode
    for mode, available in (
        (ScalingMode.DELAYED_TENSOR_SCALING, is_fp8_supported),
        (ScalingMode.MXFP8_1D_SCALING, is_mxfp8_supported),
    )
    if available
]
60
61
62
63
64
65


def is_shape_supported_by_mxfp8(input_shape):
    """Return whether ``input_shape`` has a valid MXFP8 1D-scaling scale layout.

    Args:
        input_shape: a shape tuple, or a ``pytest.param`` whose first value is
            the shape tuple.

    Returns:
        True if ``ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x`` accepts the
        shape, False if it raises.
    """
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except Exception:  # pylint: disable=broad-except
        # get_scale_shape_2x raises if the shape is not supported. Catch
        # ``Exception`` instead of using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        return False


def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
    """Assert that two quantized tensors carry matching payloads and scales.

    ``ScaledTensor1x`` pairs compare ``data`` and ``scale_inv``.
    ``ScaledTensor2x`` pairs are compared component-wise (rowwise, then
    colwise). Any other combination of types fails the test.
    """
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
        assert_allclose(a.data, b.data)
        assert (
            a.scale_inv.dtype == b.scale_inv.dtype
        ), f"scale_inv dtype mismatch: {a.scale_inv.dtype} vs {b.scale_inv.dtype}"
        if jnp.dtype(a.scale_inv.dtype).itemsize == 1:
            # Single-byte scales (presumably the MXFP8 block scales) are
            # compared on their raw bit patterns. ``astype(jnp.uint8)`` is a
            # value conversion, not a bit view, and would collapse every scale
            # below 1 (or above 255) and hide mismatches.
            assert_allclose(
                jax.lax.bitcast_convert_type(a.scale_inv, jnp.uint8),
                jax.lax.bitcast_convert_type(b.scale_inv, jnp.uint8),
            )
        else:
            # Wider (floating-point) scales, e.g. per-tensor delayed scaling,
            # are compared by value; casting them to uint8 would drop the
            # fractional part and most of their range.
            assert_allclose(a.scale_inv, b.scale_inv)
    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
        assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
        assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
    else:
        pytest.fail("Unsupported input types")


def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
    """Assert that dequantizing ``a`` reproduces the unquantized array ``b``.

    For a ``ScaledTensor1x`` stored with ``data_layout == "T"``, ``b`` is first
    transposed the same way: the axes from ``a.data.ndim - a.flatten_axis`` to
    the end are moved to the front. A ``ScaledTensor2x`` is checked through
    both its rowwise and colwise components. Tolerances follow ``a.data.dtype``.
    """
    if isinstance(a, ScaledTensor1x):
        expected = b
        if a.data_layout == "T":
            split = a.data.ndim - a.flatten_axis
            leading, trailing = range(split), range(split, b.ndim)
            expected = jnp.transpose(b, (*trailing, *leading))
        assert_allclose(a.dequantize(), expected, dtype=a.data.dtype)
        return

    if not isinstance(a, ScaledTensor2x):
        pytest.fail("a must be a ScaledTensor object")

    for component in (a.get_rowwise_tensor(), a.get_colwise_tensor()):
        assert_dequantized_scaled_tensor(component, b)


ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]

# Each activation on its own, immediately followed by the same activation
# paired with "linear".
ALL_ACTIVATION_TYPES = [
    combo
    for act_name in ("gelu", "silu", "relu", "quick_gelu", "squared_relu")
    for combo in ((act_name,), (act_name, "linear"))
]

# Keyed by test level (presumably selected by pytest_parametrize_wrapper):
# L0 runs a GELU-only subset, L2 runs every activation type.
ACTIVATION_TYPES = {
    "L0": [("gelu",), ("gelu", "linear")],
    "L2": ALL_ACTIVATION_TYPES,
}
120
121


122
123
124
125
126
127
128
129
130
class TestActivation:
    """Tests for the TE JAX ``activation`` API and the ``tex.act_lu`` primitive."""

    def ref_act(self, x, activation_type):
        """Pure-JAX reference for ``activation_type`` applied to ``x``."""
        return _jax_act_lu(x, activation_type)

    def value_n_grad_ref_func(self, x, activation_type):
        """Return the jitted reference ``(mean(ref_act(x)), (d/dx,))``."""
        jitted_reference = jit(
            value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer):
        """Scalar loss ``mean(activation(inputs))`` through the TE activation API."""
        out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
        return jnp.mean(out)

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper(
        "activation_type",
        # Test all activation types here to ensure all are functional; the other
        # tests only use a subset to verify the remaining functionality.
        ALL_ACTIVATION_TYPES,
    )
    def test_act_grad(self, shape, activation_type):
        """Unquantized TE activation value and gradient must match the reference."""
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        # One copy of the input per entry of ``activation_type`` (e.g. two for
        # ("gelu", "linear")), stacked on axis -2.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
        """Delayed-scaling FP8 activation value/gradient vs. the full-precision reference.

        ``random_inputs`` is presumably a conftest fixture built from ``shape``.
        """
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        # Tolerances are those of the FP8 output type.
        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    # Delayed tensor scaling only needs FP8 support; gating this on MXFP8 support
    # skipped it on devices that do support delayed-scaling FP8.
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_delayed_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        """TE ``act_lu`` and the JAX reference must quantize to matching tensors."""
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        # One quantizer per implementation.
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_dtype=output_type,
            q_layout=q_layout,
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        """MXFP8 ``act_lu`` output must dequantize to the full-precision reference."""
        x = random_inputs
        # The parametrized shape already has a size-1 axis at -2, so no expand_dims.
        x = jnp.repeat(x, len(activation_type), axis=-2)

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )

        output = tex.act_lu(x, activation_type, quantizer)
        ref_out = self.ref_act(x, activation_type)

        assert_dequantized_scaled_tensor(output, ref_out)
234
235


236
237
238
239
# FP8 output dtypes per test level.
# NOTE(review): not referenced by any test visible in this file.
NORM_OUTPUT_DTYPES = dict(
    L0=[jnp.float8_e4m3fn],
    L2=[jnp.float8_e4m3fn, jnp.float8_e5m2],
)
240

241

242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        """Compare the loss and (dx, dgamma, dbeta) of ``layernorm`` with the JAX reference.

        The reference path always runs with ``quantizer=None``. ``quantizer``
        (may be ``None``) is passed only to the TE ``layernorm`` call.
        Tolerances use ``quantizer.q_dtype`` when a quantizer is given, and
        ``inp_dtype`` otherwise.
        """

        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
            return ln_out

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
        x = x.astype(inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
        else:
            # The RMSNorm reference takes no beta; dbeta is not checked below.
            beta = None

        jitted_reference = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    reference_func(
                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                    )
                ),
                (0, 1, 2),
            )
        )
        jitted_primitive = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
                ),
                (0, 1, 2),
            )
        )

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_norm_grad_with_delayed_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_dtype=out_dtype,
            q_layout=q_layout,
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
        """Compare the quantized TE norm forward (output, mu, rsigma) with the JAX reference.

        One quantizer per path is created with
        ``QuantizerFactory.create(n_quantizers=2, ...)``. For RMSNorm only the
        output and rsigma are compared.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
            ref_mu = None

        if get_cudnn_version() < (9, 10, 0):
            # Reduce precision of test as we don't use fused norm below this version CuDNN for MXFP8 and instead
            # do an unfused norm and quantize with an intermediate cast into in_dtype which can reduce precision
            assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
        else:
            assert_bitwise_scaled_tensors(output, ref_out)

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_norm_forward_with_delayed_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
    ):
        """Delayed-scaling FP8 norm forward vs. the JAX reference."""
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        """MXFP8 (rowwise + colwise) norm forward vs. the JAX reference."""
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )
449
450


451
452
453
454
# FP8 output dtypes per test level.
QUANTIZE_OUTPUT_DTYPES = dict(
    L0=[jnp.float8_e4m3fn],
    L2=[jnp.float8_e4m3fn, jnp.float8_e5m2],
)

ALL_QUANTIZE_TEST_SHAPES = [(32, 64), (2, 64, 32)]

# Input shapes per test level; the L2 entry is the very same list object as
# ALL_QUANTIZE_TEST_SHAPES.
QUANTIZE_TEST_SHAPES = dict(
    L0=[(32, 256, 128), (64, 32, 32, 256)],
    L2=ALL_QUANTIZE_TEST_SHAPES,
)

# High-precision input dtypes per test level.
QUANTIZATION_INPUT_DTYPE = dict(
    L0=[jnp.bfloat16],
    L2=[jnp.float32, jnp.float16, jnp.bfloat16],
)


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    @staticmethod
    def _shape_for_flatten_axis(input_shape, flatten_axis):
        """Return the shape to quantize for ``flatten_axis``.

        For ``flatten_axis == -2``, an extra axis of size 2 is inserted before
        the last axis. This tests whether padding is done correctly when the
        input is flattened from 3D to 2D.
        """
        if flatten_axis == -2:
            return input_shape[:-1] + (2,) + input_shape[-1:]
        return input_shape

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        """Quantize-then-dequantize must reproduce the input within ``q_dtype`` tolerance."""
        key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        input_shape = self._shape_for_flatten_axis(input_shape, flatten_axis)
        # The same key is used every iteration, so the input is identical each
        # time; it is generated once instead of once per iteration.
        x = jax.random.uniform(key, input_shape, in_dtype)

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(scaled_tensor, x)

    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """``tex.quantize`` must match the pure-JAX ``_jax_quantize`` result."""
        key = jax.random.PRNGKey(0)
        input_shape = self._shape_for_flatten_axis(input_shape, flatten_axis)
        inputs = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)
        te_output = tex.quantize(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)
        assert_bitwise_scaled_tensors(te_output, jax_output)
526
527
528
529
530
531
532
533
534


@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
    """Fused quantize+dbias and dact+dbias primitives vs. their pure-JAX references."""

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper("flatten_axis", [-1, -2])
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """``tex.quantize_dbias`` must match ``_jax_quantize_dbias`` (tensor and dbias)."""
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        inputs = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        te_output, te_dbias = jit(
            lambda x: tex.quantize_dbias(x, quantizer=te_quantizer, flatten_axis=flatten_axis)
        )(inputs)

        jax_output, jax_dbias = jit(
            lambda x: _jax_quantize_dbias(x, quantizer=jax_quantizer, flatten_axis=flatten_axis)
        )(inputs)

        assert_bitwise_scaled_tensors(te_output, jax_output)

        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
        """Compare ``tex.quantize_dact_dbias`` with ``_jax_quantize_dact_dbias``.

        When ``QuantizerFactory.create`` returns a non-``None`` quantizer, the
        outputs are compared as scaled tensors; otherwise they are compared as
        plain arrays. dbias is compared only when ``is_dbias`` is set.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        # One copy of the activation input per entry of ``activation_type``,
        # stacked on axis -2; ``dz`` keeps the original ``input_shape``.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)

        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)

        if is_casted_output:
            assert_bitwise_scaled_tensors(te_output, jax_output)
        else:
            assert_allclose(te_output, jax_output)

        if is_dbias:
            assert_allclose(te_dbias, jax_dbias)

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        """dact(+dbias) without quantization (``ScalingMode.NO_SCALING``)."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_delayed_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        """dact(+dbias) with delayed-tensor-scaling FP8 output."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        """dact(+dbias) with MXFP8 1D block-scaled output."""
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )
684
685


686
class TestDense:
    """Compare ``tex.gemm`` and the ``dense`` VJP against ``jnp.dot`` references."""

    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        """Reference GEMM: swap the last two axes of each operand marked "T", then ``jnp.dot``.

        ``data_layout`` is a two-character string, one "N"/"T" flag per operand (lhs, rhs).
        """
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        """Build bf16 operands for an (m, k) x (k, n) GEMM in the requested layout.

        An "N" lhs is stored as (m, k) and a "T" lhs as (k, m). An "N" rhs is stored as
        (k, n) and a "T" rhs as (n, k). The returned contracting dims point at the K axis
        of each stored array.

        Returns:
            ``(x, w, contracting_dims)`` where ``contracting_dims`` is
            ``((lhs_dim,), (rhs_dim,))``.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)

        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        """Unquantized ``tex.gemm`` matches ``jnp.dot`` for every operand layout."""
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        primitive_out = tex.gemm(x, w, contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
        """FP8-quantized ``tex.gemm`` matches the bf16 ``jnp.dot`` reference within q_dtype tolerance."""
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
        )
        primitive_out = tex.gemm(
            x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
        )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
        """Value and gradients of unquantized ``dense`` match the ``jnp.dot`` reference."""
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
        """Value and x/w/bias gradients of FP8 ``dense`` match a bf16 reference."""
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        # Explicit 1-tuple shape: jax.random documents `shape` as a sequence of ints.
        bias = jax.random.uniform(key, (n,), dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
        )

        # Delayed scaling runs several times; only the last iteration's results are checked.
        # NOTE(review): presumably this exercises delayed-scaling state across calls — confirm.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
            )

        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    """Provide a bf16 array of the requested ``shape``, uniformly sampled in [5, 8).

    The key is the first of four subkeys split from ``PRNGKey(0)``, so the values are
    deterministic across runs.
    """
    first_subkey = jax.random.split(jax.random.PRNGKey(0), 4)[0]
    return jax.random.uniform(first_subkey, shape, jnp.bfloat16, minval=5, maxval=8)


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
    """Pure-JAX reference normalization used to check the TE kernels.

    ``norm_type == "rmsnorm"`` dispatches to ``_jax_rmsnorm``, where ``beta`` is unused.
    Any other value dispatches to ``_jax_layernorm``. Only the normalized output is kept.
    If it is a ``ScaledTensor`` (i.e. a quantizer was applied), it is dequantized before
    being returned.
    """
    if norm_type != "rmsnorm":
        normed, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    else:
        normed, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    return normed.dequantize() if isinstance(normed, ScaledTensor) else normed


class TestFusedDense:
    """VJP tests for the fused ``layernorm_dense`` and ``layernorm_mlp`` ops against JAX references."""

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
        """
        Test layernorm_dense VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)

        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        # The rmsnorm reference path does not take beta, so it is only drawn for layernorm.
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        # Delayed scaling runs several times; only the last iteration's results are checked.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_w_grad,
                prim_gamma_grad,
                prim_beta_grad,
            ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        # One key per random tensor. beta gets its own key (index 6) instead of
        # sharing subkeys[3] with bias_1, so the two tensors are drawn independently.
        subkeys = jax.random.split(key, 7)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        # The rmsnorm reference path does not take beta, so it is only drawn for layernorm.
        # beta is closed over (not differentiated); its gradient is covered in TestNorm.
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[6], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            # TODO: replace gemm with jnp.dot
            linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
            if use_bias:
                # Left-pad bias_1 with singleton axes so it broadcasts over the leading dims.
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)
            linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        # Delayed scaling runs several times; only the last iteration's results are checked.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_gamma_grad,
                prim_kernel_1_grad,
                prim_kernel_2_grad,
                prim_bias_1_grad,
                prim_bias_2_grad,
            ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)

        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)

        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)

        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)


# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
    ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
    lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
    rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
    lhs_q = lhs_quantizer.quantize(
        lhs,
        is_rowwise=lhs_is_rowwise,
        is_colwise=not lhs_is_rowwise,
    )
    rhs_q = rhs_quantizer.quantize(
        rhs,
        is_rowwise=rhs_is_rowwise,
        is_colwise=not rhs_is_rowwise,
    )
    return lhs_q, rhs_q

1053

1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
# Every (fwd, bwd) FP8 dtype pairing except E5M2 x E5M2, which is not supported.
fwd_bwd_dtypes = [
    [fwd_dtype, bwd_dtype]
    for fwd_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
    for bwd_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
    if (fwd_dtype, bwd_dtype) != (jnp.float8_e5m2, jnp.float8_e5m2)
]


@pytest_parametrize_wrapper(
    "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
)
class TestGroupedDense:
    """Compare ``tex.grouped_gemm`` / ``grouped_dense`` against per-group references.

    ``shape_list`` holds one ``(m, n, k)`` tuple per group.
    """

    def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
        """Reference: one ``jax.lax.dot_general`` per group, with no batch dims."""
        ref_out_list = []
        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
            dim_nums = (contracting_dims, ((), ()))
            ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
        return ref_out_list

    def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
        """Build per-group GEMM operands of ``dtype``.

        For each ``(m, n, k)`` and two-character layout ("N"/"T" per operand), an "N"
        lhs is (m, k) and a "T" lhs is (k, m). An "N" rhs is (k, n) and a "T" rhs is
        (n, k). Contracting dims point at each operand's K axis.

        Returns:
            ``(lhs_list, rhs_list, contracting_dims_list)``.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, len(shape_list) * 2)

        lhs_list, rhs_list, contracting_dims_list = [], [], []
        for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
            lhs = jax.random.uniform(
                subkeys[2 * i],
                (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
                dtype=dtype,
            )
            rhs = jax.random.uniform(
                subkeys[2 * i + 1],
                (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
                dtype=dtype,
            )
            lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
            rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

            lhs_list.append(lhs)
            rhs_list.append(rhs)
            contracting_dims_list.append(contracting_dims)

        return lhs_list, rhs_list, contracting_dims_list

    @staticmethod
    def _generate_bias_list(dtype, shape_list):
        """Draw one length-``n`` uniform bias per ``(m, n, k)`` group.

        Each group gets its own PRNG key, so same-length biases differ. A bias/dbias
        mix-up between groups would otherwise go undetected.
        """
        keys = jax.random.split(jax.random.PRNGKey(1), len(shape_list))
        return [
            jax.random.uniform(group_key, (shape[1],), dtype=dtype)
            for group_key, shape in zip(keys, shape_list)
        ]

    @staticmethod
    def _fp8_allclose_dtype(fwd_dtype, bwd_dtype):
        """Return E5M2 if either FP8 dtype is E5M2, otherwise E4M3.

        NOTE(review): presumably ``assert_allclose`` derives its tolerance from this
        dtype, so the coarser format governs the comparison — confirm.
        """
        if jnp.float8_e5m2 in (fwd_dtype, bwd_dtype):
            return jnp.float8_e5m2
        return jnp.float8_e4m3fn

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
    def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
        """Unquantized grouped GEMM matches per-group ``dot_general``."""
        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
            dtype, shape_list, layout_list
        )
        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
        primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
        for prim, ref in zip(primitive_out, ref_out):
            assert_allclose(prim, ref, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
        """Grouped GEMM on pre-quantized FP8 operands matches the bf16 reference."""
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
        )

        out_dtype = jnp.bfloat16
        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
            out_dtype, shape_list, layout_list
        )
        q_lhs_list = []
        q_rhs_list = []
        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
            # quantizer_set.x and quantizer_set.kernel share the same q_dtype. Using
            # quantizer_set.dgrad for rhs lets lhs and rhs have different q_dtypes.
            q_lhs, q_rhs = _quantize_gemm_pair(
                lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
            )
            q_lhs_list.append(q_lhs)
            q_rhs_list.append(q_rhs)

        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
        primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)

        allclose_dtype = self._fp8_allclose_dtype(fwd_dtype, bwd_dtype)
        for prim, ref in zip(primitive_out, ref_out):
            assert_allclose(prim, ref, dtype=allclose_dtype)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, shape_list):
        """Value and gradients of unquantized ``grouped_dense`` match per-group ``dense``."""
        group_size = len(shape_list)
        layout_list = ["NN" for _ in range(group_size)]

        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
            dtype, shape_list, layout_list
        )
        bias_list = self._generate_bias_list(dtype, shape_list)

        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
            out_list = [
                dense(x, kernel, bias, contracting_dims=contracting_dims)
                for x, kernel, bias, contracting_dims in zip(
                    x_list, kernel_list, bias_list, contracting_dims_list
                )
            ]
            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
            # and prevent it from being clamped to zero
            out_sum_list = [jnp.sum(out) for out in out_list]
            return jnp.sum(jnp.asarray(out_sum_list))

        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
            out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
            out_sum_list = [jnp.sum(out) for out in out_list]
            return jnp.sum(jnp.asarray(out_sum_list))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))

        ref_out_sum, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
            x_list, kernel_list, bias_list, contracting_dims_list
        )
        primitive_out_sum, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
            value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
        )

        assert_allclose(primitive_out_sum, ref_out_sum, dtype=dtype)
        for i in range(group_size):
            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
        """Value and gradients of FP8 ``grouped_dense`` match per-group FP8 ``dense``."""
        group_size = len(shape_list)
        layout_list = ["NN" for _ in range(group_size)]
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        if fwd_dtype == jnp.float8_e5m2:
            pytest.skip("We never use E5M2 for fwd_dtype in training")

        # Question: should we use different quantizers for different groups?
        # The reference and the primitive each get their own quantizer sets.
        ref_quantizer_set_list = []
        quantizer_set_list = []
        for _ in range(group_size):
            ref_quantizer_set = QuantizerFactory.create_set(
                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
            )
            ref_quantizer_set_list.append(ref_quantizer_set)
            quantizer_set = QuantizerFactory.create_set(
                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
            )
            quantizer_set_list.append(quantizer_set)

        out_dtype = jnp.bfloat16
        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
            out_dtype, shape_list, layout_list
        )
        bias_list = self._generate_bias_list(out_dtype, shape_list)

        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
            out_list = [
                dense(
                    x,
                    kernel,
                    bias,
                    contracting_dims=contracting_dims,
                    quantizer_set=group_quantizer_set,
                )
                for x, kernel, bias, contracting_dims, group_quantizer_set in zip(
                    x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
                )
            ]
            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
            # and prevent it from being clamped to zero
            out_sum_list = [jnp.sum(out) for out in out_list]
            return jnp.sum(jnp.asarray(out_sum_list))

        def primitive_func(
            x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
        ):
            out_list = grouped_dense(
                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
            )
            out_sum_list = [jnp.sum(out) for out in out_list]
            return jnp.sum(jnp.asarray(out_sum_list))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))

        ref_out_sum, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
            x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
        )
        primitive_out_sum, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
            value_n_grad_primitive_func(
                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
            )
        )

        allclose_dtype = self._fp8_allclose_dtype(fwd_dtype, bwd_dtype)
        assert_allclose(primitive_out_sum, ref_out_sum, dtype=allclose_dtype)
        for i in range(group_size):
            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)