# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for TE/JAX quantization, normalization and activation custom calls."""
import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator

from utils import (
    assert_allclose,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)

from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import (
    _jax_layernorm,
    _jax_rmsnorm,
    is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
    _jax_quantize,
    _jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
    NoScaleTensor,
    ScaledTensor,
    ScaledTensor1x,
    ScaledTensor2x,
    GroupedScaledTensor1x,
    ScalingMode,
    QuantizerFactory,
    QuantizeLayout,
    noop_quantizer_set,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.common import recipe

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]

# TODO(Phuong): remove unnecessary pytest skips
is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.DELAYED_TENSOR_SCALING
)
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.MXFP8_1D_SCALING
)
is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.NVFP4_1D_SCALING
)

# Find supported scaling modes on this device/driver combination.
supported_scaling_modes = helper.get_supported_scaling_modes()
non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling]
supported_recipes = helper.get_supported_quantization_recipes()
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]


def is_shape_supported_by_mxfp8(input_shape):
    """Return True if MXFP8 1D scaling can produce scale shapes for this input shape.

    Accepts either a raw shape tuple or a ``pytest.param`` wrapping one.
    """
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    # NOTE: narrowed from a bare `except:` — we only want to treat real errors from
    # get_scale_shape_2x as "unsupported", not swallow KeyboardInterrupt/SystemExit.
    except Exception:
        # get_scale_shapes will raise an exception if the shape is not supported
        return False


def assert_bitwise_scaled_tensors(
    a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
    """Assert two scaled tensors match bitwise (data, scales, layout, amax).

    With ``precise_comparison=False`` the check is relaxed: non-NVFP4 modes fall back
    to a dequantized allclose comparison; NVFP4 tolerates a small mismatch fraction.
    """
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
        if not precise_comparison and not a.scaling_mode.is_nvfp4_scaling:
            assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
            return

        assert a.scaling_mode == b.scaling_mode
        assert a.scale_inv.dtype == b.scale_inv.dtype
        assert a.data_layout == b.data_layout

        if a.scaling_mode.is_tensor_scaling():
            # Assert in dq_dtype as some unfused codepaths have an intermediate cast
            # to an input dtype which reduces precision compared to everything in fp32
            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Compare MXFP8 scales as uint8
            assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
        elif a.scaling_mode.is_nvfp4_scaling:
            assert_allclose(a.amax, b.amax)
            assert_allclose(a.scale_inv, b.scale_inv)
            if not precise_comparison:
                # Allow a small fraction of differing packed values instead of exact equality.
                mismatch = a.data != b.data
                mismatch_fraction = jnp.mean(mismatch.astype(jnp.float32))
                assert (
                    mismatch_fraction < 0.05
                ), f"Mismatch fraction {mismatch_fraction} is too high"
                return
        else:
            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
        assert_allclose(a.data, b.data)
    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
        assert_bitwise_scaled_tensors(
            a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison
        )
        assert_bitwise_scaled_tensors(
            a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison
        )
    else:
        pytest.fail("Unsupported input types")
def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
    """Assert that dequantizing ``a`` reproduces the reference array ``b``.

    Handles the transposed ("T") data layout by transposing ``b`` to match, and
    recurses into both halves of a ScaledTensor2x.
    """
    if isinstance(a, ScaledTensor1x):
        if a.data_layout == "T":
            # Undo the rowwise->colwise transpose around the flatten axis before comparing.
            flatten_axis = a.data.ndim - a.flatten_axis
            b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
            assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
        else:
            assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_scaled_tensor(a.colwise_tensor, b)
    else:
        pytest.fail("a must be a ScaledTensor object")


def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
    """Assert that dequantizing a grouped scaled tensor reproduces ``b`` group by group.

    ``b`` is split along axis 0 according to ``a.group_sizes``; empty groups are skipped.
    """
    if isinstance(a, GroupedScaledTensor1x):
        assert a.group_sizes.sum() == b.shape[0]
        b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
        dq_a = a.dequantize()
        for dq_a_i, b_i in zip(dq_a, b):
            if len(dq_a_i) == 0:
                continue
            if a.data_layout == "T":
                data_ndim = len(a.original_shape)
                flatten_axis = a.flatten_axis
                if b_i.shape[0] == 1:
                    # Keep the singleton group dimension in front when transposing.
                    b_i = jnp.transpose(
                        b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
                    )
                else:
                    b_i = jnp.transpose(
                        b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
                    )
            dq_a_i = dq_a_i.reshape(b_i.shape)
            assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
        assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
        assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
        assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
    else:
        pytest.fail("a must be a GroupedScaledTensor object")


ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [
    ("gelu",),
    ("gelu", "linear"),
    ("silu",),
    ("silu", "linear"),
    ("relu",),
    ("relu", "linear"),
    ("quick_gelu",),
    ("quick_gelu", "linear"),
    ("squared_relu",),
    ("squared_relu", "linear"),
    ("clamped_silu", "clamped_linear"),
]
# Per-test-level subsets: "L0" is the quick smoke set, "L2" covers everything.
ACTIVATION_TYPES = {
    "L0": [
        ("gelu",),
        ("gelu", "linear"),
    ],
    "L2": ALL_ACTIVATION_TYPES,
}


class TestActivation:
    """Tests for fused activation (act_lu) forward and backward custom calls."""

    def ref_act(self, x, activation_type, act_params):
        # Reference path: pure-JAX activation, unquantized (.data of a NoScaleTensor).
        return _jax_act_lu(x, activation_type, act_params=act_params).data

    def value_n_grad_ref_func(self, x, activation_type, act_params):
        # Reference value-and-grad of the mean of the activation output.
        jitted_reference = jit(
            value_and_grad(
                lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
            )
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer, act_params):
        # TE primitive path; reduced to a scalar so value_and_grad can be applied.
        out = activation(
            inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
        )
        return jnp.mean(out)

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper(
        "activation_type",
        (
            # Test all activation types for this test to ensure all are functional, then
            # just test a subset for the other tests to verify other functionality
            ALL_ACTIVATION_TYPES
        ),
    )
    def test_act_grad(self, shape, activation_type):
        """Unquantized activation: primitive value/grad must match the JAX reference."""
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        # Gated activations expect one input slice per activation in the pair.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
        )

        # clamped_silu/clamped_linear is the only pair taking extra parameters.
        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        """Quantized (tensor-scaling) activation grad must match the unquantized reference
        within the quantized dtype's tolerance."""
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)),
            static_argnums=(1, 3),
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )

        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(
            x, activation_type, quantizer, act_params
        )
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    # NOTE(review): this tensor-scaling test is gated on is_mxfp8_supported — looks like
    # it should be is_fp8_supported; confirm against upstream intent.
    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        """TE act_lu forward must be bitwise identical to the pure-JAX quantized path."""
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        # Two independent quantizers so TE and JAX paths do not share scale state.
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )

        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        """MXFP8 block-scaled act_lu output must dequantize back to the reference result."""
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )

        act_args = (
            {"limit": 0.75, "alpha": 1.702}
            if activation_type == ("clamped_silu", "clamped_linear")
            else {}
        )
        act_params = (
            tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
            if activation_type == ("clamped_silu", "clamped_linear")
            else None
        )

        output = tex.act_lu(x, activation_type, quantizer, act_params)
        ref_out = self.ref_act(x, activation_type, act_params)

        assert_dequantized_scaled_tensor(output, ref_out)
NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}


@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        """Compare value and gradients of the TE layernorm against the pure-JAX reference."""

        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
            # This is a no-op for non-quantized data
            ln_out = ln_out.dequantize()
            return ln_out

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)
        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
        x = x.astype(inp_dtype)
        # With zero-centered gamma the effective scale is gamma+1, so sample around 0.
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
        else:
            beta = None

        jitted_reference = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    reference_func(
                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                    )
                ),
                (0, 1, 2),
            )
        )
        jitted_primitive = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
                ),
                (0, 1, 2),
            )
        )

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        # Tolerances follow the quantized dtype when a quantizer is used.
        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
        )
    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
        """Compare quantized TE norm forward output bitwise against the pure-JAX reference."""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)
        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
        x = jnp.asarray(x, inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x,
                gamma,
                beta,
                zero_centered_gamma,
                epsilon,
                quantizer=ref_quantizer,
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x,
                gamma,
                zero_centered_gamma,
                epsilon,
                quantizer=ref_quantizer,
            )
            ref_mu = None

        precise_comparison = True
        if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # Reduce precision of test as we don't use fused norm below this version CuDNN
            # for MXFP8 and instead do an unfused norm and quantize with an intermediate
            # cast into in_dtype which can reduce precision
            precise_comparison = False
        elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
            # Larger tolerances as our JAX implementation _jax_*norm uses the compute
            # dtype float32 for zero-centered gamma always
            precise_comparison = False
        elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
            # Current implementation of Current Tensor Scaling performs unfused layernorm
            # and quantization and writes intermediate results into the input dtype, which
            # will slightly reduce precision if the input dtype is not float32
            precise_comparison = False

        assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)

        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """Forward-only tensor-scaling FP8 norm comparison."""
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest.mark.parametrize(
        "out_dtype",
        [
            jnp.float8_e4m3fn,
        ],
    )
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        """Forward-only MXFP8 block-scaling norm comparison (rowwise+colwise layout)."""
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )


QUANTIZE_OUTPUT_FP8_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
# Same per-level dtypes with NVFP4's packed dtype appended.
QUANTIZE_OUTPUT_DTYPES = {
    test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn]
    for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
# NOTE(review): this zips the per-level dtype list against supported_scaling_modes
# positionally — presumably intentional to limit pair count, but verify the pairing.
QUANTIZE_QDTYPE_AND_SCALING_MODES = {
    test_level: [
        (q_dtype, scaling_mode)
        for q_dtype, scaling_mode in zip(
            QUANTIZE_OUTPUT_FP8_DTYPES[test_level], supported_scaling_modes
        )
        if q_dtype in scaling_mode.get_compatible_q_dtypes()
    ]
    for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    ((32, 64), -1),
    ((2, 64, 32), -1),
    ((64, 2, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    ((64, 32, 32, 256), -1),
    ((8192, 2, 4096), -2),
]
QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((64, 2, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}
QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_layout",
    [
        QuantizeLayout.ROWWISE,
        QuantizeLayout.COLWISE,
        QuantizeLayout.ROWWISE_COLWISE,
    ],
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
        """Skip unsupported dtypes for given scaling mode. For example, NVFP4 only
        supports the float4_e2m1 dtype not float8 dtypes."""
        if q_dtype not in scaling_mode.get_compatible_q_dtypes():
            pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
        return

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        """Quantize-dequantize round trip: dequantized output must match the input
        (or, for NVFP4, re-quantizing the dequantized tensor must be bitwise stable)."""
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
        key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from
        # previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        if scaling_mode.is_nvfp4_scaling:
            if in_dtype != jnp.bfloat16:
                pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently")
                return
            q_func = _jax_quantize
            # For NVFP4 scaling, the maximum possible error for a single value can be high
            # between the dequantized and original tensors. To ensure quantization and
            # dequantization is operating correctly without requiring a very high tolerance
            # for all values, we instead test that quantizing the dequantized tensor is
            # bitwise identical to the original quantized tensor.
            x = jax.random.uniform(key, input_shape, in_dtype) * 10
            q1 = q_func(x, quantizer=quantizer, flatten_axis=flatten_axis)
            dq_rowwise = None
            dq_colwise = None
            if isinstance(q1, ScaledTensor1x):
                dq = q1.dequantize()
                if q1.is_colwise:
                    dq_colwise = dq
                else:
                    dq_rowwise = dq
            elif isinstance(q1, ScaledTensor2x):
                dq_rowwise = q1.rowwise_tensor.dequantize()
                dq_colwise = q1.colwise_tensor.dequantize()
            else:
                raise ValueError(f"Unsupported output type {type(q1)}")

            # We only compare Q-DQ for the same quantization layout. If we for example QDQ
            # rowwise, then re-quantize colwise, the error will be larger and may not be
            # bitwise identical to the original colwise quantization.
            if dq_rowwise is not None:
                assert (
                    dq_rowwise.shape == x.shape
                ), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}"
                q2_rowwise = q_func(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis)
                q2_rowwise = (
                    q2_rowwise
                    if isinstance(q2_rowwise, ScaledTensor1x)
                    else q2_rowwise.rowwise_tensor
                )
                q1_rowwise = q1 if isinstance(q1, ScaledTensor1x) else q1.rowwise_tensor
                assert_bitwise_scaled_tensors(q1_rowwise, q2_rowwise)
            if dq_colwise is not None:
                # Since this is for NVFP4, we are assuming colwise has T layout and we do a
                # transpose here to get back to original shape
                flatten_axis = (
                    flatten_axis + len(input_shape) if flatten_axis < 0 else flatten_axis
                )
                colwise_flatten_axis = len(input_shape) - flatten_axis
                dq_colwise = jnp.transpose(
                    dq_colwise,
                    (*range(colwise_flatten_axis, dq_colwise.ndim), *range(colwise_flatten_axis)),
                )
                assert (
                    dq_colwise.shape == x.shape
                ), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}"
                q2_colwise = q_func(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis)
                q2_colwise = (
                    q2_colwise
                    if isinstance(q2_colwise, ScaledTensor1x)
                    else q2_colwise.colwise_tensor
                )
                q1_colwise = q1 if isinstance(q1, ScaledTensor1x) else q1.colwise_tensor
                assert_bitwise_scaled_tensors(q1_colwise, q2_colwise)
            assert (
                dq_rowwise is not None or dq_colwise is not None
            ), "At least one of rowwise or colwise dq must be not None"
            return

        # Iterate for delayed scaling so the quantizer's scale state gets exercised.
        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)
            scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(scaled_tensor, x)

    def _should_use_precise_comparison(
        self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
    ):
        """Decide whether TE vs. JAX quantization outputs can be compared bitwise."""
        if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
            # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different
            # input dtype can lead to small numerical differences compared to the JAX
            # implementation
            return False
        return True
return True def test_quantize_bitwise( self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis ): self._skip_unsupported_dtypes(q_dtype, scaling_mode) key = jax.random.PRNGKey(0) input = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) assert_bitwise_scaled_tensors( te_output, jax_output, precise_comparison=self._should_use_precise_comparison( in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis ), ) def test_quantize_bitwise_jitted( self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis ): self._skip_unsupported_dtypes(q_dtype, scaling_mode) key = jax.random.PRNGKey(0) input = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3)) te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,)) jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis) assert_bitwise_scaled_tensors( te_output, jax_output, precise_comparison=self._should_use_precise_comparison( in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis ), ) @pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16]) @pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn]) @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper( "scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling] ) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, 
QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] ) class TestStochasticRounding: def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]: """Dequantizes a ScaledTensor back to it's original jnp.ndarray form. This always returns an array of jnp.ndarrays, for ScaledTensor2x there will be two tensors, for ScaledTensor1x there will be one tensor.""" if isinstance(scaled_tensor, ScaledTensor1x): dq = scaled_tensor.dequantize() if scaled_tensor.data_layout == "T": dq = jnp.transpose( dq, ( *range(scaled_tensor.flatten_axis, dq.ndim), *range(scaled_tensor.flatten_axis), ), ) return [dq] elif isinstance(scaled_tensor, ScaledTensor2x): [rowwise_dq] = self._dequantize(scaled_tensor.rowwise_tensor) [colwise_dq] = self._dequantize(scaled_tensor.colwise_tensor) return [rowwise_dq, colwise_dq] raise ValueError( "Unsupported ScaledTensor type, expected ScaledTensor but received" f" {type(scaled_tensor)}" ) def _sample_sr_qdq( self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis ) -> list[jnp.ndarray]: """Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors.""" dq_tensors = [] key = jax.random.PRNGKey(0) for i in range(num_samples): iter_key = jax.random.fold_in(key, i) sr_rng_state = jax.random.randint( iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32 ) quantizer = QuantizerFactory.create( q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, stochastic_rounding_rng_state=sr_rng_state, ) q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis) iter_dq = self._dequantize(q_output) dq_tensors.extend(iter_dq) avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0) assert avg_sr_tensor.shape == inputs.shape, ( f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape" f" {inputs.shape}" ) sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs)) dq_var = jnp.var(jnp.stack(dq_tensors)) assert ( dq_var > 0 ), "Variance of dequantized 
tensors is zero, stochastic rounding may not be working" return dq_tensors def _round_nearest( self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis ) -> jnp.ndarray: """Quantizes and dequantizes the input tensor with round nearest quantization.""" quantizer = QuantizerFactory.create( q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, stochastic_rounding_rng_state=None, ) q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis) rn_dq = self._dequantize(q_output)[0] return rn_dq def _test_sr( self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis ) -> float: """Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples.""" dq_tensors = self._sample_sr_qdq( num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis ) avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0) assert avg_sr_tensor.shape == inputs.shape, ( f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape" f" {inputs.shape}" ) round_nearest_tensor = self._round_nearest( q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis ) sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs)) rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs)) assert sr_mae < rn_mae, ( f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than" f" round nearest ({rn_mae})" ) return sr_mae def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. 
@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper(
    "scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
)
class TestRandomizedHadamardTransform:
    """Tests NVFP4 quantization with the randomized Hadamard transform (RHT) enabled."""

    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
    )
    @pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
    def test_rht_quantize_bitwise_jitted(
        self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
    ):
        """TE RHT quantize must match the pure-JAX RHT quantize bitwise under jit."""
        key = jax.random.PRNGKey(0)
        inputs = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            q_dtype=q_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            use_rht=True,
        )

        jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
        te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))

        jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)
        te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)

        assert_bitwise_scaled_tensors(te_output, jax_output)

    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        # Reference GEMM: un-transpose per the layout string, then plain jnp.dot.
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        # Scale by 1/sqrt(contraction dim) to keep GEMM outputs in a moderate range.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    # We do not test NN and TT layouts here as they do not have both inputs using RHT due
    # to RHT only supporting the colwise layout currently
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [True, False])
    def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm):
        """RHT-quantized GEMM must agree with an unquantized jnp.dot reference."""
        key = jax.random.PRNGKey(0)
        lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        lhs_quantizer = QuantizerFactory.create(
            scaling_mode=lhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
            use_rht=True,
        )
        rhs_quantizer = QuantizerFactory.create(
            scaling_mode=rhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
            use_rht=True,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=lhs_quantizer,
                rhs_quantizer=rhs_quantizer,
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
    """Tests grouped quantize/dequantize round-trips (tex.grouped_quantize)."""

    def test_grouped_qdq(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
    ):
        """Quantize grouped input and check it dequantizes back close to the original."""
        n_groups, m, n = input_shape
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        # *32 so that the input shapes works for MXFP8
        input_shape = (m * 32, n)
        if with_group_sizes:
            # Random partition of m rows into n_groups buckets (differences of sorted cut points).
            group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
            group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
            group_sizes = jnp.diff(group_sizes)
            assert group_sizes.sum() == m
            assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 row
            group_sizes = group_sizes * 32
        else:
            # No explicit sizes: groups become a leading axis of the input instead.
            group_sizes = None
            input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])
        # NOTE(review): flatten_axis is parametrized as [-1] above, so this branch is currently
        # inert; confirm the intended nesting (branch placement reconstructed from collapsed text).
        if flatten_axis == -2:
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
        x = jax.random.uniform(subkeys[1], input_shape, in_dtype)
        grouped_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
            n_groups=n_groups,
        )
        scaled_tensor = tex.grouped_quantize(
            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
        )
        assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)


@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
    """Tests fused quantize+dbias and quantize+dact+dbias kernels against pure-JAX references."""

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
    @pytest_parametrize_wrapper("out_dtype,scaling_mode", QUANTIZE_QDTYPE_AND_SCALING_MODES)
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """TE fused quantize+dbias must match the pure-JAX implementation bitwise."""
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        te_output, te_dbias = jit(
            lambda input: tex.quantize_dbias(
                input, quantizer=te_quantizer, flatten_axis=flatten_axis
            )
        )(input)
        jax_output, jax_dbias = jit(
            lambda input: _jax_quantize_dbias(
                input, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(input)
        assert_bitwise_scaled_tensors(te_output, jax_output)
        assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
        # Shared driver: compares TE fused quantize+dact(+dbias) against the pure-JAX reference.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        # Replicate the input along a new axis so each activation in the tuple sees the same data.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)
        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )
        # NO_SCALING mode yields a None quantizer, i.e. an unquantized (NoScaleTensor) output.
        is_casted_output = te_quantizer is not None
        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=te_quantizer,
            )
        )(dz, x)
        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz,
                x,
                activation_type=activation_type,
                is_dbias=is_dbias,
                quantizer=jax_quantizer,
            )
        )(dz, x)
        if is_casted_output:
            # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation
            precise_comparison = not (
                in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
            )
            assert_bitwise_scaled_tensors(
                te_output, jax_output, precise_comparison=precise_comparison
            )
        else:
            assert isinstance(te_output, NoScaleTensor)
            assert isinstance(jax_output, NoScaleTensor)
            assert_allclose(te_output.data, jax_output.data)
        if is_dbias:
            precise_comparison = not (
                # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
                (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
                # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
                or (
                    activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
                    and in_dtype == jnp.bfloat16
                    and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
                )
            )
            assert_allclose(
                te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
            )

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self,
        in_dtype,
        input_shape,
        activation_type,
        is_dbias,
    ):
        """dact(+dbias) with no quantization (NO_SCALING): outputs stay in the input dtype."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_quantize_dact_dbias_tensor_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
    ):
        """dact(+dbias) with delayed/current tensor-scaling FP8 quantization."""
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        """dact(+dbias) with MXFP8 1D block scaling; skips shapes the TE kernel cannot tile."""
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_layout=q_layout,
        )


# Valid FP8 GEMM operand dtype pairs (E5M2 x E5M2 is excluded).
valid_fp8_gemm_operand_types = [
    (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
    (jnp.float8_e5m2, jnp.float8_e4m3fn),
    (jnp.float8_e4m3fn, jnp.float8_e5m2),
]

# (lhs, rhs) NVFP4 scaling-mode combinations exercised by the fp4 GEMM test.
supported_nvfp4_scaling_mode_pairs = [
    (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_1D_SCALING),
    (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING),
]


class TestDense:
    """Tests tex.gemm and the dense() layer (fwd + VJP) against jnp.dot references."""

    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
        # Reference GEMM: transpose operands flagged "T", then plain jnp.dot.
        if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, data_layout):
        # Builds (x, w, contracting_dims) for an m x k @ k x n problem, with each operand
        # stored transposed when its layout letter is "T". Scaled by 1/sqrt(dim) to keep
        # the product magnitude bounded.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, data_layout):
        """Unquantized bf16 GEMM across all four operand layouts."""
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
        """FP8 GEMM with per-operand quantizers, compared to a bf16 jnp.dot reference."""
        if (
            not with_jax_gemm
            and scaling_mode.is_1d_block_scaling()
            and jnp.float8_e5m2 in (x_qtype, w_qtype)
        ):
            pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2,
            is_2x2x=False,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            # Pick the fwd (e4m3) or bwd (e5m2/dgrad) quantizer per requested operand dtype.
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=(
                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
                rhs_quantizer=(
                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)

    # TODO(Phuong): add bitwise test
    @pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("scaling_mode_pair", supported_nvfp4_scaling_mode_pairs)
    @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
    @pytest_parametrize_wrapper("with_jax_gemm", [True, False])
    def test_gemm_nvfp4(self, m, n, k, scaling_mode_pair, data_layout, with_jax_gemm):
        """NVFP4 GEMM over 1D/2D scaling-mode pairs; skips mixed-RHT operand combinations."""
        x_uses_rht = scaling_mode_pair[0] == ScalingMode.NVFP4_1D_SCALING and data_layout[0] == "T"
        w_uses_rht = scaling_mode_pair[1] == ScalingMode.NVFP4_1D_SCALING and data_layout[1] == "N"
        if x_uses_rht != w_uses_rht:
            # TODO(jberchtold): Ideally avoid a skip here and rewrite test setup to ensure both or neither use RHT
            pytest.skip("RHT must be used for both or neither operand, skipping")
        lhs_scaling_mode, rhs_scaling_mode = scaling_mode_pair
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        lhs_quantizer = QuantizerFactory.create(
            scaling_mode=lhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
        )
        rhs_quantizer = QuantizerFactory.create(
            scaling_mode=rhs_scaling_mode,
            q_dtype=jnp.float4_e2m1fn,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=lhs_quantizer,
                rhs_quantizer=rhs_quantizer,
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
        """Value and gradients of dense() in bf16 vs. a jnp.dot reference."""
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, data_layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)
        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest_parametrize_wrapper("m,n,k", [(64, 128, 128)])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm):
        """Value and gradients (x, w, bias) of quantized dense() for every supported recipe."""
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        key = jax.random.PRNGKey(1)
        # NOTE(review): shape is passed as a bare int `n` rather than `(n,)` — relies on the
        # installed JAX accepting scalar shapes here; verify.
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, data_layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
        quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
        # Delayed-scaling recipes need a few iterations for the scale factors to settle.
        n_iterations = 3 if recipe.delayed() else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                    value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
                )
        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, data_layout
        )
        assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    # Uniform bf16 inputs in [5, 8) for the requested shape.
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
    """Pure-JAX layernorm/rmsnorm reference; returns the dequantized normalized output."""
    if norm_type == "rmsnorm":
        ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    ln_out = ln_out.dequantize()
    return ln_out


class TestFusedDense:
    """Tests fused layernorm_dense and layernorm_mlp VJP rules against pure-JAX references."""

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm):
        """
        Test layernorm_dense VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
        quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
        # beta only exists for layernorm; rmsnorm has no shift parameter.
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            # Unquantized norm followed by a plain dot as the reference path.
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )
        # Delayed-scaling recipes need a few iterations for the scale factors to settle.
        n_iterations = 3 if recipe.delayed() else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_w_grad,
                    prim_gamma_grad,
                    prim_beta_grad,
                ) = value_n_grad_prim_func(x, w, gamma, beta)
        assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("recipe", supported_recipes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, recipe, norm_type, use_bias, with_jax_gemm
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)
        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        # kernel_1 carries one slice per activation in the gated-activation tuple.
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        beta = None  # was tested in TestNorm
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None
        # One quantizer set per GEMM in the MLP.
        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            fp8_recipe=recipe,
        )
        # NOTE(review): overwrites the `beta = None` above and reuses subkeys[3] (also used for
        # bias_1) — presumably intentional for a deterministic test; verify.
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            # Unfused reference: norm -> GEMM1 (+bias) -> activation -> GEMM2 (+bias).
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)
            x = _jax_act_lu(linear_1_out, activation_type).data
            linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)
            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))
        # Delayed-scaling recipes need a few iterations for the scale factors to settle.
        n_iterations = 3 if recipe.delayed() else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,
                    prim_gamma_grad,
                    prim_kernel_1_grad,
                    prim_kernel_2_grad,
                    prim_bias_1_grad,
                    prim_bias_2_grad,
                ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
        fwd_dtype = quantizer_sets[0].x.q_dtype
        bwd_dtype = quantizer_sets[0].dgrad.q_dtype
        assert_allclose(prim_out, ref_out, dtype=fwd_dtype)
        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=bwd_dtype)
        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=bwd_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=bwd_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=bwd_dtype)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype)
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype)


# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]

GROUPED_DENSE_INPUT_SHAPES = [
    # (n_groups, m, n, k), the actual m will be multiplied by 32
    # Test the case where n_groups is not a multiple of 4
    (5, 32, 128, 64),
    (8, 64, 32, 128),
    (8, 64, 128, 256),
]


@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
    """Tests grouped GEMM and grouped_dense (fwd + VJP) against per-group dot_general references."""

    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
        # Reference: split lhs by group_sizes along its non-contracting axis and run one
        # dot_general (+bias) per group; returns a list of per-group outputs.
        lhs_contract_dim, _ = contracting_dims
        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
        if bias is None:
            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
        else:
            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
        bias = jnp.split(bias, bias.shape[0], axis=0)
        ref_out = []
        dim_num = (contracting_dims, ((), ()))
        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
            out_i = jax.lax.dot_general(
                lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
            ) + jnp.expand_dims(bias_i, axis=0)
            ref_out.append(jnp.squeeze(out_i))
        return ref_out

    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
        # Builds (lhs, rhs, group_sizes, contracting_dims, bias) with random group boundaries;
        # group 1 is forced empty to exercise zero-row GEMM handling.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        n_groups, m, n, k = input_shape
        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
        group_sizes = jnp.diff(group_sizes)
        # Make one empty input lhs to test empty GEMM handling
        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
        group_sizes = group_sizes.at[1].set(0)
        assert group_sizes.sum() == m
        # *32 to make sure that input shape works for MXFP8
        group_sizes = group_sizes * 32
        m = m * 32
        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
        bias_shape = (n_groups, n)
        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None
        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
        return lhs, rhs, group_sizes, contracting_dims, bias

    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
        # Splits the flat grouped output back into per-group slices and compares each.
        assert out.dtype == ref_list[0].dtype
        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
        for i in range(len(ref_list)):
            assert_allclose(out_list[i], ref_list[i], dtype=dtype)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
        """Unquantized grouped GEMM in bf16/fp16 with async D2H group-size copies."""
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            dtype, input_shape, layout
        )
        num_gemms = input_shape[0]
        # Warm the group-size copy path under jit before running the GEMM itself.
        _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
            group_sizes,
            num_gemms=num_gemms,
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
        # jitting grouped_gemm
        prim_out = jax.jit(
            tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
        )(
            lhs,
            rhs,
            group_sizes,
            contracting_dims,
            use_async_d2h_group_sizes=True,
        )
        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
        """FP8 grouped GEMM over mixed e4m3/e5m2 operand dtype pairs."""
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=False,
            n_groups=input_shape[0],
        )
        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
        # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
        quantizer_set.kernel.q_dtype = bwd_dtype
        for quantizer in quantizer_set.kernel.quantizers:
            quantizer.q_dtype = bwd_dtype
        out_dtype = jnp.bfloat16
        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
            out_dtype, input_shape, layout
        )
        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )
        # Use the loosest (e5m2) tolerance whenever either operand is e5m2.
        allclose_dtype = jnp.float8_e4m3fn
        if jnp.float8_e5m2 in fwd_bwd_dtype:
            allclose_dtype = jnp.float8_e5m2
        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)

    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
        # Scalar loss over the grouped reference outputs, for value_and_grad comparisons.
        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
        # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
        # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to
        # normalize the output and prevent the gradient from being too large for FP8.
        out_sum_list = [jnp.sum(out) for out in out_list]
        return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)

    def _primitive_sum_grouped_dense(
        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
    ):
        # Same scalar loss as _ref_sum_grouped_dense, but via the grouped_dense primitive.
        out = grouped_dense(
            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
        )
        return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)

    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
        """Value and gradients (x, kernel, bias) of unquantized grouped_dense."""
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )
        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        # jitting the grouped_dense
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )
        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x, kernel, bias, group_sizes, contracting_dims
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims
        )
        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize(
        "fwd_bwd_dtype",
        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
    )
    @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
        """Value and gradients of FP8-quantized grouped_dense vs. an unquantized reference."""
        fwd_dtype, bwd_dtype = fwd_bwd_dtype
        dtype = jnp.bfloat16
        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
            dtype,
            input_shape,
            with_bias=True,
        )
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=fwd_dtype,
            bwd_dtype=bwd_dtype,
            is_2x2x=True,
            n_groups=group_sizes.size,
        )
        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
        # jitting the grouped_dense
        value_n_grad_prim_func = jit(
            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
        )
        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
            x,
            kernel,
            bias,
            group_sizes,
            contracting_dims,
        )
        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
        )
        # Forward output is tolerance-checked at the fwd FP8 dtype, gradients at the bwd dtype;
        # dbias accumulates in the input dtype.
        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)