# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
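"""Tests for Transformer Engine's JAX custom-call compute primitives: activations,
layernorm/rmsnorm, quantization, and dense/grouped GEMM."""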

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    pytest_parametrize_wrapper,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp

from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import (
    _jax_layernorm,
    _jax_rmsnorm,
    is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
    DelayedScaleQuantizer,
    ScaledTensor,
    ScaledTensor1x,
    ScaledTensor2x,
    GroupedScaledTensor1x,
    ScalingMode,
    QuantizerFactory,
    QuantizeLayout,
    noop_quantizer_set,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Find the supported scaling modes
supported_scaling_modes = []
if is_fp8_supported:
    supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
    supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
    supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)


def is_shape_supported_by_mxfp8(input_shape):
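    """Return True if MXFP8 1D scaling defines a valid 2x scale shape for ``input_shape``."""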
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except Exception:
        # get_scale_shape_2x raises an exception if the shape is not supported
        return False


def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
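    """Assert that two scaled tensors match: same scaling mode, matching scale_inv, and bitwise-equal data."""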
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
        assert a.scaling_mode == b.scaling_mode
        assert a.scale_inv.dtype == b.scale_inv.dtype
        if a.scaling_mode.is_tensor_scaling():
            # Assert in dq_dtype as some unfused codepaths have an intermediate cast
            # to an input dtype which reduces precision compared to everything in fp32
            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Compare MXFP8 scales as uint8
            assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
        else:
            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
        assert_allclose(a.data, b.data)

    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
        assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
        assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
    else:
        pytest.fail("Unsupported input types")


def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
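    """Assert that dequantizing ``a`` reproduces the reference array ``b`` within the tolerance of its data dtype."""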
    if isinstance(a, ScaledTensor1x):
        if a.data_layout == "T":
            flatten_axis = a.data.ndim - a.flatten_axis
            b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
            assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
        else:
            assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b)
        assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b)
    else:
        pytest.fail("a must be a ScaledTensor object")


def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
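    """Assert that dequantizing each group of ``a`` reproduces the corresponding rows of ``b``."""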
    if isinstance(a, GroupedScaledTensor1x):
        assert a.group_sizes.sum() == b.shape[0]
        b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
        dq_a = a.dequantize()
        for dq_a_i, b_i in zip(dq_a, b):
            if len(dq_a_i) == 0:
                continue
            if a.data_layout == "T":
                data_ndim = len(a.original_shape)
                flatten_axis = a.flatten_axis
                if b_i.shape[0] == 1:
                    b_i = jnp.transpose(
                        b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
                    )
                else:
                    b_i = jnp.transpose(
                        b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
                    )
            dq_a_i = dq_a_i.reshape(b_i.shape)
            assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
        assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
        assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
        assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
    else:
        pytest.fail("a must be a GroupedScaledTensor object")


ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [
    ("gelu",),
    ("gelu", "linear"),
    ("silu",),
    ("silu", "linear"),
    ("relu",),
    ("relu", "linear"),
    ("quick_gelu",),
    ("quick_gelu", "linear"),
    ("squared_relu",),
    ("squared_relu", "linear"),
]

ACTIVATION_TYPES = {
    "L0": [
        ("gelu",),
        ("gelu", "linear"),
    ],
    "L2": ALL_ACTIVATION_TYPES,
}


class TestActivation:
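    """
    Test transformer_engine.jax.activation.activation and fused activation + quantization
    """
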
    def ref_act(self, x, activation_type):
        return _jax_act_lu(x, activation_type)

    def value_n_grad_ref_func(self, x, activation_type):
        jitted_reference = jit(
            value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer):
        out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
        return jnp.mean(out)

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper(
        "activation_type",
        (
            ALL_ACTIVATION_TYPES  # Test all activation types for this test to ensure all are functional, then just test a subset for the other tests to verify other functionality
        ),
    )
    def test_act_grad(self, shape, activation_type):
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )

        output = tex.act_lu(x, activation_type, quantizer)
        ref_out = self.ref_act(x, activation_type)

        assert_dequantized_scaled_tensor(output, ref_out)


NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}


@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
            # if isinstance(ln_out, ScaledTensor):
            #     ln_out = ln_out.dequantize()
            return ln_out

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
        x = x.astype(inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
        else:
            beta = None

        jitted_reference = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    reference_func(
                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                    )
                ),
                (0, 1, 2),
            )
        )
        jitted_primitive = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
                ),
                (0, 1, 2),
            )
        )

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
        x = jnp.asarray(x, inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
            ref_mu = None

        precise_comparison = True

        if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Reduce the test precision: below this cuDNN version we don't use a fused norm
            # for MXFP8 and instead do an unfused norm and quantize, with an intermediate
            # cast to in_dtype which can reduce precision.
            precise_comparison = False
        elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
            # Larger tolerances as our JAX implementation _jax_*norm uses the compute dtype float32
            # for zero-centered gamma always
            precise_comparison = False
        elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
            # Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
            # and writes intermediate results into the input dtype, which will slightly reduce precision
            # if the input dtype is not float32
            precise_comparison = False

        if precise_comparison:
            assert_bitwise_scaled_tensors(output, ref_out)
        else:
            if isinstance(ref_out, ScaledTensor1x):
                assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
            elif isinstance(ref_out, ScaledTensor2x):
                assert_allclose(
                    output.rowwise_tensor.dequantize(),
                    ref_out.rowwise_tensor.dequantize(),
                    dtype=out_dtype,
                )
                assert_allclose(
                    output.colwise_tensor.dequantize(),
                    ref_out.colwise_tensor.dequantize(),
                    dtype=out_dtype,
                )
            else:
                pytest.fail("Unsupported output type")

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )


QUANTIZE_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}

ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    ((32, 64), -1),
    ((2, 64, 32), -1),
    ((64, 2, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    ((64, 32, 32, 256), -1),
    ((64, 32, 32, 256), -2),
    ((64, 32, 32, 256), -3),
]

QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((64, 2, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}

QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)

            scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(scaled_tensor, x)

    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)

        te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
        assert_bitwise_scaled_tensors(te_output, jax_output)


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
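    """
    Test tex.grouped_quantize with and without explicit group sizes
    """
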
    def test_grouped_qdq(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
    ):
        n_groups, m, n = input_shape
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # *32 so that the input shapes works for MXFP8
        input_shape = (m * 32, n)

        if with_group_sizes:
            group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
            group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
            group_sizes = jnp.diff(group_sizes)
            assert group_sizes.sum() == m
            assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 row
            group_sizes = group_sizes * 32
        else:
            group_sizes = None
            input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])

        if flatten_axis == -2:
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]

        x = jax.random.uniform(subkeys[1], input_shape, in_dtype)

        grouped_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
            n_groups=n_groups,
        )

        scaled_tensor = tex.grouped_quantize(
            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
        )

        assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)


@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
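    """
    Test fused quantize + dbias and quantize + dact + dbias against the unfused JAX references
    """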

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        te_output, te_dbias = jit(
            lambda input: tex.quantize_dbias(
                input, quantizer=te_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        jax_output, jax_dbias = jit(
            lambda input: _jax_quantize_dbias(
                input, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        assert_bitwise_scaled_tensors(te_output, jax_output)

        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)

        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)

        if is_casted_output:
            assert_bitwise_scaled_tensors(te_output, jax_output)
        else:
            assert_allclose(te_output, jax_output)

        if is_dbias:
            assert_allclose(te_dbias, jax_dbias)

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_quantize_dact_dbias_tensor_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )


class TestDense:
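    """
    Test tex.gemm and the dense layer in BF16 and with FP8 quantizer sets
    """
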
    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
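        """Reference GEMM using jnp.dot, transposing the operands according to ``data_layout``."""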
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
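        """Generate (x, w, contracting_dims) for an (m, n, k) GEMM in the given data layout."""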
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        primitive_out = tex.gemm(x, w, contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
        )
        primitive_out = tex.gemm(
            x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
        )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
            )

        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
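    """Random bfloat16 inputs of the parametrized shape, drawn uniformly from [5, 8)."""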
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
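    """Reference norm that dequantizes the output if the quantizer produced a ScaledTensor."""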
    if norm_type == "rmsnorm":
        ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    if isinstance(ln_out, ScaledTensor):
        ln_out = ln_out.dequantize()
    return ln_out


class TestFusedDense:
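    """
    Test layernorm_dense and layernorm_mlp VJP rules against unfused references
    """
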
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
        """
        Test layernorm_dense VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        ):
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)

        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_w_grad,
                prim_gamma_grad,
                prim_beta_grad,
            ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        ):
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        beta = None  # was tested in TestNorm
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            # TODO: replace gemm with jnp.dot
            linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)
            linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_gamma_grad,
                prim_kernel_1_grad,
                prim_kernel_2_grad,
                prim_bias_1_grad,
                prim_bias_2_grad,
            ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)

        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)

        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)

        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)


# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]

GROUPED_DENSE_INPUT_SHAPES = [
    # (n_groups, m, n, k), the actual m will be multiplied by 32
    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
    (8, 64, 32, 128),
    (8, 64, 128, 256),
]


@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
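    """
    Test grouped GEMM and the grouped dense layer in FP16/BF16 and with FP8 quantizer sets
    """
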
    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
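        """Reference grouped dense: split lhs by group_sizes and apply dot_general plus bias per group."""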
        lhs_contract_dim, _ = contracting_dims
        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
        if bias is None:
            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
        else:
            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
        bias = jnp.split(bias, bias.shape[0], axis=0)
        ref_out = []
        dim_num = (contracting_dims, ((), ()))
        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
            out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
            ref_out.append(jnp.squeeze(out_i))
        return ref_out

    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
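        """Generate (lhs, rhs, group_sizes, contracting_dims, bias); m is scaled by 32 so the shapes work for MXFP8."""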
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        n_groups, m, n, k = input_shape

        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
        group_sizes = jnp.diff(group_sizes)
        assert group_sizes.sum() == m

        # *32 to make sure that input shape works for MXFP8
        group_sizes = group_sizes * 32
        m = m * 32

        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
        bias_shape = (n_groups, n)

        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None

        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return lhs, rhs, group_sizes, contracting_dims, bias

    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
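        """Split ``out`` by group_sizes and compare each group against the reference output list."""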
        assert out.dtype == ref_list[0].dtype
        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
        for i in range(len(ref_list)):
            assert_allclose(out_list[i], ref_list[i], dtype=dtype)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            pytest.skip("MXFP8 is not supported in grouped_gemm yet")

        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=False,
            n_groups=input_shape[0],
        )

        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
        # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
        quantizer_set.kernel.q_dtype = bwd_dtype
        for quantizer in quantizer_set.kernel.quantizers:
            quantizer.q_dtype = bwd_dtype

        out_dtype = jnp.bfloat16
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            out_dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
        prim_out = tex.grouped_gemm(
            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        allclose_dtype = jnp.float8_e4m3fn
        if jnp.float8_e5m2 in fwd_bwd_dtype:
            allclose_dtype = jnp.float8_e5m2

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)

    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
        # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
        # and prevent them from being clamp to zero
        out_sum_list = [jnp.sum(out) for out in out_list]
        return jnp.sum(jnp.asarray(out_sum_list))

    def _primitive_sum_grouped_dense(
        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
    ):
        out = grouped_dense(
            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
        )
        return jnp.sum(jnp.asarray(out))

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x, kernel, bias, group_sizes, contracting_dims
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize(
        "fwd_bwd_dtype",
        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
    )
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            pytest.skip("MXFP8 is not supported in grouped_dense yet")

        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        dtype = jnp.bfloat16
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=True,
            n_groups=group_sizes.size,
        )
        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x,
            kernel,
            bias,
            group_sizes,
            contracting_dims,
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)