test_custom_call_compute.py 56.5 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.

import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
9
from functools import reduce
10
from typing import Union
11
12
13
14
15
import operator

from utils import (
    assert_allclose,
    pytest_parametrize_wrapper,
Alp Dener's avatar
Alp Dener committed
16
    use_jax_gemm,
17
18
19
20
21
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp

from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
22
23
24
25
26
from transformer_engine.jax.cpp_extensions.normalization import (
    _jax_layernorm,
    _jax_rmsnorm,
    is_norm_zero_centered_gamma_in_weight_dtype,
)
27
28
29
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
30
)
31
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
32
from transformer_engine.jax import cpp_extensions as tex
33
34
from transformer_engine.jax.quantize import (
    ScaledTensor,
35
36
37
    ScaledTensor1x,
    ScaledTensor2x,
    GroupedScaledTensor1x,
38
39
    ScalingMode,
    QuantizerFactory,
40
    QuantizeLayout,
41
    noop_quantizer_set,
42
43
44
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
45
from transformer_engine.jax.dense import dense, grouped_dense
46
from transformer_engine.jax.layernorm_dense import layernorm_dense
47

Tim Moon's avatar
Tim Moon committed
48
49
50
51
52
53
54
# GEMM problem sizes, three integers per case.
# NOTE(review): presumably ordered (m, n, k) -- confirm against the GEMM tests that consume it.
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]

# The two FP8 formats used as compute types.
FP8_COMPUTE_TYPE = [
    jnp.float8_e4m3fn,
    jnp.float8_e5m2,
]

# (n, hidden) pairs used to parametrize TestNorm.
LN_CASES = [
    (256, 128),
    (128, 256),
]

# Input dtypes used to parametrize TestNorm.
DTYPES = [
    jnp.bfloat16,
    jnp.float32,
]
58
59
# Probe FP8 availability once at import time. Each probe yields a
# (supported, reason) pair that the skipif markers in this file consume.
is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Find supported scaling modes
supported_scaling_modes = []
if is_fp8_supported:
    supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
    supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
    supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
68
69
70
71
72
73


def is_shape_supported_by_mxfp8(input_shape):
    """Return whether ``input_shape`` is accepted by the MXFP8 1D scaling mode.

    Args:
        input_shape: A shape, or a ``pytest.param`` whose first value is the shape.

    Returns:
        True if ``ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x`` accepts the shape,
        False if the lookup raises.
    """
    try:
        # Parametrized cases may arrive wrapped in pytest.param(...); unwrap the shape.
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
    except Exception:
        # get_scale_shape_2x raises when the shape is not supported. Catch Exception
        # rather than using a bare except so KeyboardInterrupt/SystemExit still propagate.
        return False
    else:
        return True


81
82
83
def assert_bitwise_scaled_tensors(
    a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
    """Assert that two scaled tensors of the same kind match.

    For a pair of ``ScaledTensor1x`` the scaling mode, scale dtype, scales and data are
    compared. For a pair of ``ScaledTensor2x`` the rowwise and colwise halves are compared
    recursively. Any other combination fails the test.

    Args:
        a: First scaled tensor.
        b: Second scaled tensor.
        precise_comparison: When False, 1x tensors are only compared after dequantization,
            using the tolerance of ``a.data.dtype``.
    """
    both_1x = isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x)

    if not both_1x:
        if isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
            # Rowwise first, then colwise.
            for part in ("rowwise_tensor", "colwise_tensor"):
                assert_bitwise_scaled_tensors(
                    getattr(a, part), getattr(b, part), precise_comparison=precise_comparison
                )
        else:
            pytest.fail("Unsupported input types")
        return

    if not precise_comparison:
        assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
        return

    assert a.scaling_mode == b.scaling_mode
    assert a.scale_inv.dtype == b.scale_inv.dtype

    mode = a.scaling_mode
    if mode.is_tensor_scaling():
        # Assert in dq_dtype as some unfused codepaths have an intermediate cast
        # to an input dtype which reduces precision compared to everything in fp32
        assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
    elif mode == ScalingMode.MXFP8_1D_SCALING:
        # Compare MXFP8 scales as uint8
        scale_a = a.scale_inv.astype(jnp.uint8)
        scale_b = b.scale_inv.astype(jnp.uint8)
        assert_allclose(scale_a, scale_b)
    else:
        raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")

    assert_allclose(a.data, b.data)


def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
    """Assert that dequantizing ``a`` reproduces ``b`` within the tolerance of ``a.data.dtype``.

    Args:
        a: A ``ScaledTensor1x`` or ``ScaledTensor2x``; anything else fails the test.
        b: The array to compare against. When ``a.data_layout`` is ``"T"``, ``b`` is
            transposed around the flatten axis before the comparison.
    """
    if isinstance(a, ScaledTensor1x):
        expected = b
        if a.data_layout == "T":
            # Move the trailing axes (from the split point on) to the front.
            split = a.data.ndim - a.flatten_axis
            permutation = (*range(split, b.ndim), *range(split))
            expected = jnp.transpose(b, permutation)
        assert_allclose(a.dequantize(), expected, dtype=a.data.dtype)
        return

    if isinstance(a, ScaledTensor2x):
        # Both halves must dequantize to the same reference.
        for part in ("rowwise_tensor", "colwise_tensor"):
            assert_dequantized_scaled_tensor(getattr(a, part), b)
        return

    pytest.fail("a must be a ScaledTensor object")


128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
    """Assert that each dequantized group of ``a`` matches the corresponding slice of ``b``.

    Args:
        a: A ``GroupedScaledTensor1x``, or a ``ScaledTensor2x`` whose rowwise and colwise
            tensors are both ``GroupedScaledTensor1x``; anything else fails the test.
        b: The array to compare against. It is split along axis 0 at the cumulative
            ``a.group_sizes`` boundaries.
    """
    if isinstance(a, GroupedScaledTensor1x):
        assert a.group_sizes.sum() == b.shape[0]
        boundaries = jnp.cumulative_sum(a.group_sizes)[:-1]
        references = jnp.split(b, boundaries, axis=0)

        for dequantized, reference in zip(a.dequantize(), references):
            # Empty groups have nothing to compare.
            if len(dequantized) == 0:
                continue

            if a.data_layout == "T":
                ndim = len(a.original_shape)
                axis = a.flatten_axis
                if reference.shape[0] == 1:
                    # Keep the leading size-1 axis in place and rotate the remaining axes.
                    permutation = (0, *range(axis, ndim), *range(1, axis))
                else:
                    permutation = (*range(axis, ndim), *range(axis))
                reference = jnp.transpose(reference, permutation)

            assert_allclose(
                dequantized.reshape(reference.shape), reference, dtype=a.data.dtype
            )
        return

    if isinstance(a, ScaledTensor2x):
        assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
        assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
        assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
        return

    pytest.fail("a must be a GroupedScaledTensor object")


160
161
162
163
164
165
166
167
168
169
170
171
172
# Input shapes used to parametrize the activation tests.
ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]

# Every base activation, first on its own and then paired with "linear".
ALL_ACTIVATION_TYPES = [
    variant
    for base in ("gelu", "silu", "relu", "quick_gelu", "squared_relu")
    for variant in ((base,), (base, "linear"))
]

# "L0" is a reduced subset, "L2" is the full list.
# NOTE(review): the keys look like test levels selected by pytest_parametrize_wrapper -- confirm.
ACTIVATION_TYPES = {
    "L0": [
        ("gelu",),
        ("gelu", "linear"),
    ],
    "L2": ALL_ACTIVATION_TYPES,
}
181
182


183
184
185
186
187
188
189
190
191
class TestActivation:
    """Compare the ``activation`` / ``tex.act_lu`` primitives with the ``_jax_act_lu`` reference.

    NOTE(review): the FP8 tests request a ``random_inputs`` fixture that is not defined in
    this class; presumably it is provided elsewhere (e.g. a conftest) and consumes the
    ``shape`` parameter -- confirm.
    """

    def ref_act(self, x, activation_type):
        """Return the reference activation computed by ``_jax_act_lu``."""
        return _jax_act_lu(x, activation_type)

    def value_n_grad_ref_func(self, x, activation_type):
        """Return ``(value, (grad,))`` of the mean reference activation with respect to ``x``."""
        jitted_reference = jit(
            value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer):
        """Return the mean of ``activation(inputs)``; the scalar that gets differentiated."""
        out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
        return jnp.mean(out)

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper(
        "activation_type",
        (
            ALL_ACTIVATION_TYPES  # Test all activation types for this test to ensure all are functional, then just test a subset for the other tests to verify other functionality
        ),
    )
    def test_act_grad(self, shape, activation_type):
        """Unquantized value and gradient must match the reference for every activation."""
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        # Insert an axis at position -2 and repeat along it once per entry of activation_type.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        """Quantized value and gradient must match the unquantized reference at FP8 tolerance."""
        x = random_inputs
        # Insert an axis at position -2 and repeat along it once per entry of activation_type.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    # This test only exercises tensor-scaling modes, so it is gated on plain FP8 support
    # (not MXFP8), consistent with the other tensor-scaling tests in this file.
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        """``tex.act_lu`` must match ``_jax_act_lu`` when both quantize with tensor scaling."""
        x = random_inputs
        # Insert an axis at position -2 and repeat along it once per entry of activation_type.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        # Two independent quantizers, one per implementation.
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        """MXFP8-quantized ``tex.act_lu`` output must dequantize to the reference activation."""
        x = random_inputs
        # No expand_dims here: only the repeat along axis -2 is applied.
        # NOTE(review): presumably the parametrized shape already carries the size-1 axis -- confirm.
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )

        output = tex.act_lu(x, activation_type, quantizer)
        ref_out = self.ref_act(x, activation_type)

        assert_dequantized_scaled_tensor(output, ref_out)
303
304


305
306
307
308
# FP8 output dtypes for the norm tests: "L0" is a reduced subset, "L2" the full list.
# NOTE(review): the keys look like test levels selected by pytest_parametrize_wrapper -- confirm.
NORM_OUTPUT_DTYPES = {
    "L0": [
        jnp.float8_e4m3fn,
    ],
    "L2": [
        jnp.float8_e4m3fn,
        jnp.float8_e5m2,
    ],
}
309

310

311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        """Check the loss value and gradients of ``layernorm`` against the pure-JAX norm.

        The reference always runs without a quantizer; the primitive uses ``quantizer``.
        The tolerance follows ``quantizer.q_dtype`` when a quantizer is given, otherwise
        ``inp_dtype``.
        """

        def compute_loss(x):
            # Higher precision to compute the loss
            squared = jnp.square(x.astype(jnp.float32))
            return jnp.mean(squared).astype(x.dtype)

        def reference_loss(x, gamma, beta):
            # The reference path is never quantized, hence the trailing None.
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, None)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, None)
            return compute_loss(ln_out)

        def primitive_loss(x, gamma, beta):
            return compute_loss(
                layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
            )

        x_key, gamma_key, beta_key = jax.random.split(jax.random.PRNGKey(0), 3)

        # Sample in float32, then cast to the dtype under test.
        x = jax.random.uniform(x_key, (n, hidden), jnp.float32, -1, 1).astype(inp_dtype)

        gamma_low, gamma_high = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jnp.asarray(
            jax.random.uniform(gamma_key, (hidden,), jnp.float32, gamma_low, gamma_high),
            inp_dtype,
        )

        # Only layernorm has a bias term.
        beta = None
        if norm_type == "layernorm":
            beta = jnp.asarray(
                jax.random.uniform(beta_key, (hidden,), jnp.float32, -1, 1), inp_dtype
            )

        jitted_reference = jit(value_and_grad(reference_loss, (0, 1, 2)))
        jitted_primitive = jit(value_and_grad(primitive_loss, (0, 1, 2)))

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        comparisons = [
            (primitive_out, reference_out),
            (primitive_dx, reference_dx),
            (primitive_dgamma, reference_dgamma),
        ]
        if beta is not None:
            comparisons.append((primitive_dbeta, reference_dbeta))
        for actual, expected in comparisons:
            assert_allclose(actual, expected, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        fp8_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, fp8_quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
        """Compare the quantized forward norm from ``tex`` with the pure-JAX implementation.

        Checks the quantized output, ``rsigma`` and, for layernorm, ``mu``.
        """
        x_key, gamma_key, beta_key = jax.random.split(jax.random.PRNGKey(0), 3)

        x = jax.random.uniform(x_key, (n, hidden), inp_dtype, -1, 1)
        x = jnp.asarray(x, inp_dtype)

        gamma_low, gamma_high = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jnp.asarray(
            jax.random.uniform(gamma_key, (hidden,), jnp.float32, gamma_low, gamma_high),
            inp_dtype,
        )

        # Two independent quantizers, one per implementation.
        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )

        is_layernorm = norm_type == "layernorm"
        if is_layernorm:
            beta = jnp.asarray(
                jax.random.uniform(beta_key, (hidden,), jnp.float32, -1, 1), inp_dtype
            )
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
            ref_mu = None

        # Fall back to a dequantized comparison whenever one of these holds:
        #  1. CuDNN older than 9.10.0 with MXFP8: the fused norm is not used below this
        #     version; an unfused norm + quantize with an intermediate cast into the
        #     input dtype is done instead, which can reduce precision.
        #  2. Zero-centered gamma is kept in the weight dtype: the JAX implementation
        #     _jax_*norm always uses the float32 compute dtype for zero-centered gamma.
        #  3. Current tensor scaling with a non-float32 input: layernorm and quantization
        #     are unfused and intermediate results are written in the input dtype.
        needs_relaxed_comparison = (
            (get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING)
            or is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode)
            or (scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32)
        )
        precise_comparison = not needs_relaxed_comparison

        assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if is_layernorm:
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """Forward norm check for the delayed and current tensor scaling modes."""
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        """Forward norm check for MXFP8 1D scaling with the rowwise+colwise layout."""
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )
552
553


554
555
556
557
# FP8 output dtypes for the fused-quantize tests: "L0" is a reduced subset, "L2" the full list.
# NOTE(review): the keys look like test levels selected by pytest_parametrize_wrapper -- confirm.
QUANTIZE_OUTPUT_DTYPES = {
    "L0": [
        jnp.float8_e4m3fn,
    ],
    "L2": [
        jnp.float8_e4m3fn,
        jnp.float8_e5m2,
    ],
}

# (input_shape, flatten_axis) pairs.
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    # 2D
    ((32, 64), -1),
    # 3D
    ((2, 64, 32), -1),
    ((64, 2, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    # 4D
    ((64, 32, 32, 256), -1),
    ((64, 32, 32, 256), -2),
    ((64, 32, 32, 256), -3),
]

QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((64, 2, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}

# Input dtypes fed to the quantization tests.
QUANTIZATION_INPUT_DTYPE = {
    "L0": [
        jnp.bfloat16,
    ],
    "L2": [
        jnp.float32,
        jnp.float16,
        jnp.bfloat16,
    ],
}


585
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        """Quantize then dequantize must reproduce the input within FP8 tolerance."""
        rng_key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
            n_iterations = 3
        else:
            n_iterations = 1

        for _ in range(n_iterations):
            # The same key is reused, so every iteration quantizes the same values.
            x = jax.random.uniform(rng_key, input_shape, in_dtype)
            quantized = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(quantized, x)

    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """``tex.quantize`` must match the pure-JAX ``_jax_quantize``."""
        rng_key = jax.random.PRNGKey(0)
        x = jax.random.uniform(rng_key, input_shape, in_dtype)

        # Two independent quantizers, one per implementation.
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(x, quantizer=jax_quantizer, flatten_axis=flatten_axis)
        te_output = tex.quantize(x, quantizer=te_quantizer, flatten_axis=flatten_axis)

        assert_bitwise_scaled_tensors(te_output, jax_output)
630
631


632
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
    """Quantize/dequantize round-trip tests for ``tex.grouped_quantize``."""

    def test_grouped_qdq(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
    ):
        """Grouped quantize then dequantize must reproduce the input group by group.

        ``input_shape`` is unpacked as ``(n_groups, m, n)``. With ``with_group_sizes`` the
        input is 2D with explicit group sizes along axis 0; otherwise it is 3D with a
        leading ``n_groups`` axis and ``group_sizes`` is None.
        """
        n_groups, m, n = input_shape
        sizes_key, data_key = jax.random.split(jax.random.PRNGKey(0), 2)

        # *32 so that the input shapes works for MXFP8
        total_rows = m * 32

        if with_group_sizes:
            # Sorted random cut points in [0, m), bracketed by 0 and m; their differences
            # are the group sizes.
            cut_points = jnp.sort(jax.random.randint(sizes_key, (n_groups - 1,), 0, m))
            edges = jnp.concatenate([jnp.array([0]), cut_points, jnp.array([m])])
            group_sizes = jnp.diff(edges)
            assert group_sizes.sum() == m
            assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 row
            group_sizes = group_sizes * 32
            data_shape = (total_rows, n)
        else:
            group_sizes = None
            data_shape = (n_groups, total_rows // n_groups, n)

        if flatten_axis == -2:
            # Insert an extra axis of size 2 just before the last axis.
            data_shape = (*data_shape[:-1], 2, *data_shape[-1:])

        x = jax.random.uniform(data_key, data_shape, in_dtype)

        grouped_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
            n_groups=n_groups,
        )

        quantized = tex.grouped_quantize(
            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
        )

        assert_dequantized_grouped_scaled_tensor(quantized, x)


683
684
685
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
    """Compare the fused TE quantize primitives against their pure-JAX counterparts.

    Every test runs the ``tex`` primitive and the matching ``_jax_*`` reference on the same
    random input, each with its own quantizer created by a single ``QuantizerFactory.create``
    call, and compares the two results.
    """

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """Check ``tex.quantize_dbias`` against ``_jax_quantize_dbias``."""
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
        x = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        te_output, te_dbias = jit(
            lambda x: tex.quantize_dbias(x, quantizer=te_quantizer, flatten_axis=flatten_axis)
        )(x)

        jax_output, jax_dbias = jit(
            lambda x: _jax_quantize_dbias(
                x, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(x)

        assert_bitwise_scaled_tensors(te_output, jax_output)

        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
        """Shared body of the ``test_quantize_dact_dbias_*`` tests.

        Runs ``tex.quantize_dact_dbias`` and ``_jax_quantize_dact_dbias`` on the same random
        ``dz``/``x`` pair and compares the outputs, plus the dbias when ``is_dbias`` is set.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        # Insert an axis of length len(activation_type) just before the last axis of x.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        # A None quantizer means the outputs are compared with assert_allclose instead of
        # the scaled-tensor comparison below.
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)

        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)

        if is_casted_output:
            # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation
            precise_comparison = not (
                in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
            )
            assert_bitwise_scaled_tensors(
                te_output, jax_output, precise_comparison=precise_comparison
            )
        else:
            assert_allclose(te_output, jax_output)

        if is_dbias:
            # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
            precise_comparison = not (
                in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
            )
            assert_allclose(
                te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
            )

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        """Run the dact/dbias comparison with ``ScalingMode.NO_SCALING`` and ``out_dtype=in_dtype``."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_quantize_dact_dbias_tensor_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
    ):
        """Run the dact/dbias comparison for the delayed and current tensor-scaling modes."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        """Run the dact/dbias comparison with ``ScalingMode.MXFP8_1D_SCALING``.

        Shapes whose flattened leading dimensions or last dimension are not multiples of 128
        are skipped.
        """
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )


# (x_qtype, w_qtype) pairs used to parametrize TestDense.test_gemm_fp8.
# The E5M2 x E5M2 combination is not listed.
valid_fp8_gemm_operand_types = [
    (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
    (jnp.float8_e5m2, jnp.float8_e4m3fn),
    (jnp.float8_e4m3fn, jnp.float8_e5m2),
]


class TestDense:
    """Compare ``tex.gemm`` and ``dense`` (forward and gradients) against ``jnp.dot``."""

    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        """Reference GEMM: swap the last two axes of each operand marked "T", then ``jnp.dot``.

        ``data_layout`` is a two-character string; character 0 applies to ``a`` and
        character 1 to ``b``.
        """
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        """Create random bfloat16 GEMM operands and contracting dims for ``data_layout``.

        Returns ``(x, w, contracting_dims)``. ``x`` is ``(m, k)`` for layout "N" and
        ``(k, m)`` otherwise; ``w`` is ``(k, n)`` for layout "N" and ``(n, k)`` otherwise.
        ``contracting_dims`` selects the ``k`` axis of each operand.
        """
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        """Check unquantized ``tex.gemm`` against the ``jnp.dot`` reference for each layout."""
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
        """Check quantized ``tex.gemm`` against the unquantized ``jnp.dot`` reference.

        An operand requested as E4M3 uses the forward quantizer of the set (``x`` / ``kernel``);
        an operand requested as E5M2 uses the ``dgrad`` quantizer.
        """
        if (
            not with_jax_gemm
            and scaling_mode.is_1d_block_scaling()
            and jnp.float8_e5m2 in (x_qtype, w_qtype)
        ):
            pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")

        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2,
            is_2x2x=False,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=(
                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
                rhs_quantizer=(
                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
        """Check ``dense`` output and its x/w gradients against the ``jnp.dot`` reference."""
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))

        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
        """Check quantized ``dense`` (with bias) output and x/w/bias gradients against the reference."""
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        # The shape is passed as a tuple, as jax.random.uniform documents it.
        bias = jax.random.uniform(key, (n,), dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )

        # Delayed tensor scaling is run three times; only the last iteration's results
        # are compared below.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                    value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
                )

        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )

        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    """Return a bfloat16 array of the given shape, sampled uniformly between 5 and 8."""
    root_key = jax.random.PRNGKey(0)
    # Four keys are derived; only the first one is consumed.
    derived_keys = jax.random.split(root_key, 4)
    return jax.random.uniform(derived_keys[0], shape, jnp.bfloat16, 5, 8)


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
    """Apply the pure-JAX reference normalization and return an unquantized result.

    ``norm_type == "rmsnorm"`` dispatches to ``_jax_rmsnorm`` (``beta`` is not used there);
    any other value dispatches to ``_jax_layernorm``. If the normalized output is a
    ``ScaledTensor`` it is dequantized before being returned.
    """
    if norm_type == "rmsnorm":
        normed, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        normed, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    return normed.dequantize() if isinstance(normed, ScaledTensor) else normed


class TestFusedDense:
    """Gradient tests for the fused ``layernorm_dense`` and ``layernorm_mlp`` operations."""

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
        """
        Test layernorm_dense VJP Rule

        The quantized ``layernorm_dense`` output and its gradients w.r.t. x, w, gamma (and
        beta for layernorm) are compared against an unquantized reference norm + ``jnp.dot``.
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)

        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )

        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        # Delayed tensor scaling is run three times; only the last iteration's results
        # are compared below.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_w_grad,
                    prim_gamma_grad,
                    prim_beta_grad,
                ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
    ):
        """
        Test layernorm_mlp VJP Rule

        The quantized ``layernorm_mlp`` output and its gradients w.r.t. x, gamma, both kernels
        and (when ``use_bias``) both biases are compared against an unquantized reference
        built from the reference norm, ``jax.lax.dot_general`` and ``_jax_act_lu``.
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )

        # NOTE(review): subkeys[3] is also the key used for bias_1 above. It is left as-is
        # because switching to a fresh key would change the generated test data.
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
            if use_bias:
                # Left-pad the bias shape with 1s so it broadcasts against the GEMM output.
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)
            linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        # Delayed tensor scaling is run three times; only the last iteration's results
        # are compared below.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_gamma_grad,
                    prim_kernel_1_grad,
                    prim_kernel_2_grad,
                    prim_bias_1_grad,
                    prim_bias_2_grad,
                ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)

        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)

        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
        if use_bias:
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)

        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)


# FP8 dtype pairs; presumably [fwd_dtype, bwd_dtype] going by the name — the usage is
# outside this part of the file, TODO confirm.
# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]

# Values of the "input_shape" parameter that TestGroupedDense is parametrized over.
GROUPED_DENSE_INPUT_SHAPES = [
    # (n_groups, m, n, k), the actual m will be multiplied by 32
    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
    (8, 64, 32, 128),
    (8, 64, 128, 256),
]


@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
        lhs_contract_dim, _ = contracting_dims
        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
        if bias is None:
            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
        else:
            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
        bias = jnp.split(bias, bias.shape[0], axis=0)
        ref_out = []
        dim_num = (contracting_dims, ((), ()))
        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
1260
1261
1262
            out_i = jax.lax.dot_general(
                lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
            ) + jnp.expand_dims(bias_i, axis=0)
1263
1264
1265
1266
            ref_out.append(jnp.squeeze(out_i))
        return ref_out

    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
1267
        key = jax.random.PRNGKey(0)
1268
1269
1270
1271
1272
1273
        subkeys = jax.random.split(key, 4)
        n_groups, m, n, k = input_shape

        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
        group_sizes = jnp.diff(group_sizes)
1274
1275
1276
        # Make one empty input lhs to test empty GEMM handling
        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
        group_sizes = group_sizes.at[1].set(0)
1277
1278
1279
1280
1281
        assert group_sizes.sum() == m

        # *32 to make sure that input shape works for MXFP8
        group_sizes = group_sizes * 32
        m = m * 32
1282

1283
1284
1285
        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
        bias_shape = (n_groups, n)
1286

1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None

        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return lhs, rhs, group_sizes, contracting_dims, bias

    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
        """Check a concatenated grouped GEMM output against per-group references.

        ``out`` is split along axis 0 at the cumulative ``group_sizes`` offsets and
        each chunk is compared with the matching entry of ``ref_list`` through
        ``assert_allclose``, with ``dtype`` forwarded to that helper.
        """
        assert out.dtype == ref_list[0].dtype
        split_points = jnp.cumulative_sum(group_sizes)[:-1]
        out_chunks = jnp.split(out, split_points, axis=0)
        for idx, expected in enumerate(ref_list):
            assert_allclose(out_chunks[idx], expected, dtype=dtype)
1302
1303

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
        """Compare jitted ``tex.grouped_gemm`` with the per-group reference (no quantizer)."""
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        # jitting grouped_gemm
        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
            lhs,
            rhs,
            group_sizes,
            contracting_dims,
        )

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
1320

1321
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
        """Compare jitted ``tex.grouped_gemm`` using an FP8 quantizer set with the reference.

        The inputs are generated in bfloat16 and the reference is computed from
        those unquantized inputs.
        """
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=False,
            n_groups=input_shape[0],
        )

        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
        # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
        quantizer_set.kernel.q_dtype = bwd_dtype
        for quantizer in quantizer_set.kernel.quantizers:
            quantizer.q_dtype = bwd_dtype

        out_dtype = jnp.bfloat16
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            out_dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        # Compare with the E5M2 dtype whenever either operand is quantized to E5M2.
        allclose_dtype = jnp.float8_e4m3fn
        if jnp.float8_e5m2 in fwd_bwd_dtype:
            allclose_dtype = jnp.float8_e5m2

        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)
1356

1357
1358
1359
    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
        # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
1360
1361
        # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to
        # normalize the output and prevent the gradient from being too large for FP8.
1362
        out_sum_list = [jnp.sum(out) for out in out_list]
1363
        return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)
1364
1365
1366
1367
1368
1369

    def _primitive_sum_grouped_dense(
        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
    ):
        """Reduce the ``grouped_dense`` output to one scalar for ``value_and_grad``.

        Uses the same reduction as ``_ref_sum_grouped_dense`` (a plain sum divided
        by ``sqrt(x.size)``) so the two values and their gradients are comparable.
        """
        out = grouped_dense(
            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
        )
        return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)
1372

1373
1374
1375
1376
1377
1378
1379
    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
        """Check ``grouped_dense`` value and gradients against the reference (no quantizer).

        Gradients are taken with respect to ``x``, ``kernel`` and ``bias``.
        """
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        # jitting the grouped_dense; contracting_dims (positional index 4) is marked static
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x, kernel, bias, group_sizes, contracting_dims
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
1398

1399
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize(
        "fwd_bwd_dtype",
        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
    )
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
        """Check ``grouped_dense`` value and gradients with an FP8 quantizer set.

        Inputs are generated in bfloat16. The output sum is compared with the
        forward dtype, dgrad and wgrad with the backward dtype, and dbias with
        the bfloat16 input dtype.
        """
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        dtype = jnp.bfloat16
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=True,
            n_groups=group_sizes.size,
        )
        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))

        # jitting the grouped_dense; contracting_dims (positional index 4) is marked static
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )

        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x,
            kernel,
            bias,
            group_sizes,
            contracting_dims,
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )

        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)