# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator

from utils import (
    assert_allclose,
    pytest_parametrize_wrapper,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp

from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import (
    _jax_layernorm,
    _jax_rmsnorm,
    is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
    DelayedScaleQuantizer,
    ScaledTensor,
    ScaledTensor1x,
    ScaledTensor2x,
    GroupedScaledTensor1x,
    ScalingMode,
    QuantizerFactory,
    QuantizeLayout,
    noop_quantizer_set,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Find supported scaling modes
supported_scaling_modes = []
if is_fp8_supported:
    supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
    supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
    supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)


def is_shape_supported_by_mxfp8(input_shape):
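    """Return True if MXFP8 1D scaling can produce scale shapes for the given input shape."""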
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except Exception:
        # get_scale_shape_2x raises an exception if the shape is not supported
        return False


def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
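    """Assert that two ScaledTensors carry matching quantized data and scales, recursing into 2x tensors."""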
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
        assert a.scaling_mode == b.scaling_mode
        assert a.scale_inv.dtype == b.scale_inv.dtype
        if a.scaling_mode.is_tensor_scaling():
            # Assert in dq_dtype as some unfused codepaths have an intermediate cast
            # to an input dtype which reduces precision compared to everything in fp32
            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Compare MXFP8 scales as uint8
            assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
        else:
            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
        assert_allclose(a.data, b.data)

    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
        assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
        assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
    else:
        pytest.fail("Unsupported input types")


def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
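    """Assert that dequantizing ScaledTensor `a` recovers the reference array `b`, handling transposed layouts."""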
    if isinstance(a, ScaledTensor1x):
        if a.data_layout == "T":
            flatten_axis = a.data.ndim - a.flatten_axis
            b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
            assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
        else:
            assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b)
        assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b)
    else:
        pytest.fail("a must be a ScaledTensor object")


def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
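    """Grouped variant of assert_dequantized_scaled_tensor: check each group of `a` against the matching slice of `b`."""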
    if isinstance(a, GroupedScaledTensor1x):
        assert a.group_sizes.sum() == b.shape[0]
        b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
        dq_a = a.dequantize()
        for dq_a_i, b_i in zip(dq_a, b):
            if len(dq_a_i) == 0:
                continue
            if a.data_layout == "T":
                data_ndim = len(a.original_shape)
                flatten_axis = a.flatten_axis
                if b_i.shape[0] == 1:
                    b_i = jnp.transpose(
                        b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
                    )
                else:
                    b_i = jnp.transpose(
                        b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
                    )
            dq_a_i = dq_a_i.reshape(b_i.shape)
            assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
        assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
        assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
        assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
    else:
        pytest.fail("a must be a GroupedScaledTensor object")


ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [
    ("gelu",),
    ("gelu", "linear"),
    ("silu",),
    ("silu", "linear"),
    ("relu",),
    ("relu", "linear"),
    ("quick_gelu",),
    ("quick_gelu", "linear"),
    ("squared_relu",),
    ("squared_relu", "linear"),
]

ACTIVATION_TYPES = {
    "L0": [
        ("gelu",),
        ("gelu", "linear"),
    ],
    "L2": ALL_ACTIVATION_TYPES,
}


class TestActivation:
    def ref_act(self, x, activation_type):
        return _jax_act_lu(x, activation_type)

    def value_n_grad_ref_func(self, x, activation_type):
        jitted_reference = jit(
            value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer):
        out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
        return jnp.mean(out)

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper(
        "activation_type",
        (
            ALL_ACTIVATION_TYPES  # Test every activation type here to ensure all are functional; the other tests cover only a subset
        ),
    )
    def test_act_grad(self, shape, activation_type):
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )

        output = tex.act_lu(x, activation_type, quantizer)
        ref_out = self.ref_act(x, activation_type)

        assert_dequantized_scaled_tensor(output, ref_out)


NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}


@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
            # if isinstance(ln_out, ScaledTensor):
            #     ln_out = ln_out.dequantize()
            return ln_out

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
        x = x.astype(inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
        else:
            beta = None

        jitted_reference = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    reference_func(
                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                    )
                ),
                (0, 1, 2),
            )
        )
        jitted_primitive = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
                ),
                (0, 1, 2),
            )
        )

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
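        """Run the TE fused norm forward pass and compare its quantized output and stats against the pure-JAX reference."""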
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
        x = jnp.asarray(x, inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
            ref_mu = None

        precise_comparison = True

        if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Reduce the test precision: below this cuDNN version we don't use the fused norm for MXFP8 and
            # instead do an unfused norm and quantize with an intermediate cast to in_dtype, which can reduce precision
            precise_comparison = False
        elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
            # Larger tolerances as our JAX implementation _jax_*norm uses the compute dtype float32
            # for zero-centered gamma always
            precise_comparison = False
        elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
            # Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
            # and writes intermediate results into the input dtype, which will slightly reduce precision
            # if the input dtype is not float32
            precise_comparison = False

        if precise_comparison:
            assert_bitwise_scaled_tensors(output, ref_out)
        else:
            if isinstance(ref_out, ScaledTensor1x):
                assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
            elif isinstance(ref_out, ScaledTensor2x):
                assert_allclose(
                    output.rowwise_tensor.dequantize(),
                    ref_out.rowwise_tensor.dequantize(),
                    dtype=out_dtype,
                )
                assert_allclose(
                    output.colwise_tensor.dequantize(),
                    ref_out.colwise_tensor.dequantize(),
                    dtype=out_dtype,
                )
            else:
                pytest.fail("Unsupported output type")

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )


QUANTIZE_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}

ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    ((32, 64), -1),
    ((2, 64, 32), -1),
    ((64, 2, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    ((64, 32, 32, 256), -1),
    ((64, 32, 32, 256), -2),
    ((64, 32, 32, 256), -3),
]

QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((64, 2, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}

QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization-related tests that always run on a wider set of types and shapes
    """

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)

            scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(scaled_tensor, x)

    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)

        te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
        assert_bitwise_scaled_tensors(te_output, jax_output)


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
    def test_grouped_qdq(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
    ):
        n_groups, m, n = input_shape
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        # *32 so that the input shape works for MXFP8
        input_shape = (m * 32, n)

        if with_group_sizes:
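            # Build random group sizes that sum to m by taking differences of sorted random cut points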
            group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
            group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
            group_sizes = jnp.diff(group_sizes)
            assert group_sizes.sum() == m
            assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 row
            group_sizes = group_sizes * 32
        else:
            group_sizes = None
            input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])

        if flatten_axis == -2:
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]

        x = jax.random.uniform(subkeys[1], input_shape, in_dtype)

        grouped_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
            n_groups=n_groups,
        )

        # grouped_quantize does not work with cudaGraph yet, so jitting will break.
        # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
        # disable cudaGraph, then use the jitted form sketched below
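        # Jitted form (a sketch only, not exercised by default; assumes flatten_axis and quantizer
        # can be passed as static arguments to jax.jit):
        # scaled_tensor = jax.jit(
        #     tex.grouped_quantize, static_argnames=("flatten_axis", "quantizer")
        # )(x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer)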

        scaled_tensor = tex.grouped_quantize(
            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
        )

        assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)


@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        te_output, te_dbias = jit(
            lambda input: tex.quantize_dbias(
                input, quantizer=te_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        jax_output, jax_dbias = jit(
            lambda input: _jax_quantize_dbias(
                input, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(input)

        assert_bitwise_scaled_tensors(te_output, jax_output)

        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
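        """Compare tex.quantize_dact_dbias against the pure-JAX reference for the given activation and quantizer setup."""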
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)

        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)

        if is_casted_output:
            assert_bitwise_scaled_tensors(te_output, jax_output)
        else:
            assert_allclose(te_output, jax_output)

        if is_dbias:
            assert_allclose(te_dbias, jax_dbias)

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_quantize_dact_dbias_tensor_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this skip if a newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and fall back to the
            # JAX implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )


class TestDense:
    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        primitive_out = tex.gemm(x, w, contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
        )
        primitive_out = tex.gemm(
            x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
        )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
            )

        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
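    """Pure-JAX reference norm; dequantizes the output if it comes back as a ScaledTensor."""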
    if norm_type == "rmsnorm":
        ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    if isinstance(ln_out, ScaledTensor):
        ln_out = ln_out.dequantize()
    return ln_out


class TestFusedDense:
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
        """
        Test layernorm_dense VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        ):
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)

        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_w_grad,
                prim_gamma_grad,
                prim_beta_grad,
            ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        ):
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        beta = None  # was tested in TestNorm
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            # TODO: replace gemm with jnp.dot
            linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)
            linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_gamma_grad,
                prim_kernel_1_grad,
                prim_kernel_2_grad,
                prim_bias_1_grad,
                prim_bias_2_grad,
            ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)

        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)

        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)

        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)


# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]

GROUPED_DENSE_INPUT_SHAPES = [
    # (n_groups, m, n, k), the actual m will be multiplied by 32
    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
    (8, 64, 32, 128),
    (8, 64, 128, 256),
]


@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
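        """Reference grouped dense: split lhs by group_sizes and apply a per-group dot_general plus bias."""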
        lhs_contract_dim, _ = contracting_dims
        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
        if bias is None:
            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
        else:
            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
        bias = jnp.split(bias, bias.shape[0], axis=0)
        ref_out = []
        dim_num = (contracting_dims, ((), ()))
        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
            out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
            ref_out.append(jnp.squeeze(out_i))
        return ref_out

    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        n_groups, m, n, k = input_shape

        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
        group_sizes = jnp.diff(group_sizes)
        # Make one empty input lhs to test empty GEMM handling
        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
        group_sizes = group_sizes.at[1].set(0)
        assert group_sizes.sum() == m

        # *32 to make sure that the input shape works for MXFP8
        group_sizes = group_sizes * 32
        m = m * 32

        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
        bias_shape = (n_groups, n)

        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None

        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return lhs, rhs, group_sizes, contracting_dims, bias

    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
        assert out.dtype == ref_list[0].dtype
        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
        for i in range(len(ref_list)):
            assert_allclose(out_list[i], ref_list[i], dtype=dtype)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        # grouped_gemm does not work with cudaGraph yet, so jitting will break.
        # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
        # disable cudaGraph, then use the following jitted function

        # jitting grouped_gemm
        # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
        #     lhs, rhs, group_sizes, contracting_dims,
        # )

        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=False,
            n_groups=input_shape[0],
        )

        # quantizer_set.{x, kernel} uses fwd_dtype, while quantizer_set.grad uses bwd_dtype.
        # To test E4M3 * E5M2, manually set quantizer_set.kernel.q_dtype to bwd_dtype
        quantizer_set.kernel.q_dtype = bwd_dtype
        for quantizer in quantizer_set.kernel.quantizers:
            quantizer.q_dtype = bwd_dtype

        out_dtype = jnp.bfloat16
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            out_dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        # jitting grouped_gemm
        # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
        #         lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
        #         )

        prim_out = tex.grouped_gemm(
            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        allclose_dtype = jnp.float8_e4m3fn
        if jnp.float8_e5m2 in fwd_bwd_dtype:
            allclose_dtype = jnp.float8_e5m2

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)

    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
        # Note: we use jnp.sum instead of jnp.mean to make the gradients larger
        # and prevent them from being clamped to zero
        out_sum_list = [jnp.sum(out) for out in out_list]
        return jnp.sum(jnp.asarray(out_sum_list))

    def _primitive_sum_grouped_dense(
        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
    ):
        out = grouped_dense(
            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
        )
        return jnp.sum(jnp.asarray(out))

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        # jitting the grouped_dense
        # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
        #                              static_argnums=(4,))
        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x, kernel, bias, group_sizes, contracting_dims
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize(
        "fwd_bwd_dtype",
        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
    )
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        dtype = jnp.bfloat16
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=True,
            n_groups=group_sizes.size,
        )
        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))

        # jitting the grouped_dense
        # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
        #                              static_argnums=(4,))
        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x,
            kernel,
            bias,
            group_sizes,
            contracting_dims,
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)