Unverified Commit 818b30cc authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quantization (#2270)



* [JAX] Support recipe flags for disabling SR, RHT, and 2D quantization
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix issue with SR state being erased due to pytree handling of NVFP4Quantizer
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Add test for SR state preservation across VJP boundaries
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix sharding of SR rng state
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* update tolerances slightly now that SR is enabled
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Use hashlib for deterministic hashes across runs for SR
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* rename uses_rht on scaled tensors to has_applied_rht
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* add assert
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Move decision of whether to use RHT into helper.py and add dedicated RHT tests
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix use_rht attr usage
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* fix pure-jax rht usage criteria
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Adjust tolerances after rebase
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent ce2e8bd1
...@@ -672,7 +672,7 @@ class TestEncoder(unittest.TestCase): ...@@ -672,7 +672,7 @@ class TestEncoder(unittest.TestCase):
def test_te_nvfp4(self): def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4""" """Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling") result = self.exec(True, "NVFP4BlockScaling")
assert result[0] < 0.451 and result[1] > 0.79 assert result[0] < 0.451 and result[1] > 0.788
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self): def test_te_bf16_shardy(self):
...@@ -710,7 +710,7 @@ class TestEncoder(unittest.TestCase): ...@@ -710,7 +710,7 @@ class TestEncoder(unittest.TestCase):
def test_te_nvfp4_shardy(self): def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4""" """Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
assert result[0] < 0.451 and result[1] > 0.79 assert result[0] < 0.451 and result[1] > 0.788
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -390,7 +390,7 @@ class TestEncoder(unittest.TestCase): ...@@ -390,7 +390,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling" self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.476 and actual[1] > 0.775 assert actual[0] < 0.477 and actual[1] > 0.769
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -40,7 +40,6 @@ from transformer_engine.jax.quantize import ( ...@@ -40,7 +40,6 @@ from transformer_engine.jax.quantize import (
QuantizerFactory, QuantizerFactory,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
should_use_rht,
) )
from transformer_engine.jax.quantize import helper from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation from transformer_engine.jax.activation import activation
...@@ -685,21 +684,14 @@ class TestQuantize: ...@@ -685,21 +684,14 @@ class TestQuantize:
Purely quantization related tests that will always test on a wider set of types and shapes Purely quantization related tests that will always test on a wider set of types and shapes
""" """
def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
"""Temporary hack to skip unsupported FP4 cases until we implement them""" """Skip unsupported dtypes for given scaling mode. For example, NVFP4 only supports the float4_e2m1 dtype not float8 dtypes."""
if q_dtype not in scaling_mode.get_compatible_q_dtypes(): if q_dtype not in scaling_mode.get_compatible_q_dtypes():
pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}") pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
return return
# HACK: FIXME TODO(jberchtold)
row = reduce(operator.mul, input_shape[flatten_axis:], 1)
col = reduce(operator.mul, input_shape[:flatten_axis], 1)
will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
if will_use_rht and (row % 64 != 0 or col % 128 != 0):
pytest.skip("Unfused RHT is not supported currently, skipping")
def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) self._skip_unsupported_dtypes(q_dtype, scaling_mode)
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
...@@ -780,22 +772,8 @@ class TestQuantize: ...@@ -780,22 +772,8 @@ class TestQuantize:
assert_dequantized_scaled_tensor(scaled_tensor, x) assert_dequantized_scaled_tensor(scaled_tensor, x)
def _should_use_precise_comparison( def _should_use_precise_comparison(
self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
): ):
# TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical results between TE and JAX RHT+quant implementations. Currently for certain shapes the quantized fp4 data differs by a small amount on <0.5% of the values.
RHT_SLIGHT_MISMATCH_SHAPES = [
((32, 256, 128), -1),
((64, 32, 32, 256), -1),
((8192, 2, 4096), -2),
]
if (
should_use_rht(scaling_mode, q_layout=q_layout)
and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES
):
# TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes
return False
if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16: if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
# With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation
return False return False
...@@ -805,7 +783,7 @@ class TestQuantize: ...@@ -805,7 +783,7 @@ class TestQuantize:
def test_quantize_bitwise( def test_quantize_bitwise(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
): ):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) self._skip_unsupported_dtypes(q_dtype, scaling_mode)
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype) input = jax.random.uniform(key, input_shape, in_dtype)
...@@ -816,28 +794,20 @@ class TestQuantize: ...@@ -816,28 +794,20 @@ class TestQuantize:
jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
try: te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
except AssertionError as e:
if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
error_message = e.args[0]
if "RHT requires input to be bfloat16" in error_message:
# Successfully caught the expected error, early return from the test
return
raise e
assert_bitwise_scaled_tensors( assert_bitwise_scaled_tensors(
te_output, te_output,
jax_output, jax_output,
precise_comparison=self._should_use_precise_comparison( precise_comparison=self._should_use_precise_comparison(
in_dtype, scaling_mode, q_layout, input_shape, flatten_axis in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
), ),
) )
def test_quantize_bitwise_jitted( def test_quantize_bitwise_jitted(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
): ):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) self._skip_unsupported_dtypes(q_dtype, scaling_mode)
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype) input = jax.random.uniform(key, input_shape, in_dtype)
...@@ -851,21 +821,13 @@ class TestQuantize: ...@@ -851,21 +821,13 @@ class TestQuantize:
jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
try: te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
except AssertionError as e:
if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
error_message = e.args[0]
if "RHT requires input to be bfloat16" in error_message:
# Successfully caught the expected error, early return from the test
return
raise e
assert_bitwise_scaled_tensors( assert_bitwise_scaled_tensors(
te_output, te_output,
jax_output, jax_output,
precise_comparison=self._should_use_precise_comparison( precise_comparison=self._should_use_precise_comparison(
in_dtype, scaling_mode, q_layout, input_shape, flatten_axis in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
), ),
) )
...@@ -985,12 +947,6 @@ class TestStochasticRounding: ...@@ -985,12 +947,6 @@ class TestStochasticRounding:
def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other.""" """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
# HACK: FIXME TODO(jberchtold)
row = reduce(operator.mul, input_shape[flatten_axis:], 1)
col = reduce(operator.mul, input_shape[:flatten_axis], 1)
will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
if will_use_rht and (row % 64 != 0 or col % 128 != 0):
pytest.skip("Unfused RHT is not supported currently, skipping")
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype) inputs = jax.random.uniform(key, input_shape, in_dtype)
...@@ -1007,6 +963,97 @@ class TestStochasticRounding: ...@@ -1007,6 +963,97 @@ class TestStochasticRounding:
assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4) assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)
@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper(
"scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
)
class TestRandomizedHadamardTransform:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
@pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
def test_rht_quantize_bitwise_jitted(
self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
):
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
use_rht=True,
)
jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))
jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)
te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)
assert_bitwise_scaled_tensors(te_output, jax_output)
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T":
a = jnp.swapaxes(a, -1, -2)
if data_layout[1] == "T":
b = jnp.swapaxes(b, -1, -2)
return jnp.dot(a, b)
def _generate_gemm_input(self, m, n, k, data_layout):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(
subkeys[0],
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=jnp.bfloat16,
) / jnp.sqrt(k)
w = jax.random.uniform(
subkeys[1],
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=jnp.bfloat16,
) / jnp.sqrt(n)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
return (x, w, contracting_dims)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
# We do not test NN and TT layouts here as they do not have both inputs using RHT due to RHT only supporting the colwise layout currently
@pytest_parametrize_wrapper("data_layout", ["TN", "NT"])
@pytest_parametrize_wrapper("with_jax_gemm", [True, False])
def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm):
key = jax.random.PRNGKey(0)
lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
lhs_quantizer = QuantizerFactory.create(
scaling_mode=lhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
use_rht=True,
)
rhs_quantizer = QuantizerFactory.create(
scaling_mode=rhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
use_rht=True,
)
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
......
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
# See LICENSE for license information. # See LICENSE for license information.
import unittest import unittest
from functools import partial
import flax import flax
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from flax import linen as nn
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.common.recipe import ( from transformer_engine.common.recipe import (
...@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import ( ...@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
ScalingMode, ScalingMode,
update_collections, update_collections,
TensorSource, TensorSource,
QuantizerFactory,
QuantizeLayout,
) )
from transformer_engine.jax.quantize.helper import _format2dtypes from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax.flax.module import TransformerEngineBase
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
"""Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
# Define a function with a custom VJP (vector-Jacobian product)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def quantizer_check(inner_quantizer_set, assertion_func, x):
return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)
def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
assertion_func(inner_quantizer_set.x, TensorSource.X)
assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
return x
def quantizer_check_bwd(ctx, g):
return (g,)
quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
return quantizer_check(outer_quantizer_set, assertion_func, x)
class TestModule(TransformerEngineBase):
"""A simple module to test quantizer creation and reconstruction across VJP boundaries."""
# Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
assertion_func: callable
@nn.compact
def __call__(self, x):
quantizer_set = self.generate_quantizer_set()
return quantizer_check_vjp(quantizer_set, self.assertion_func, x)
class TestHelper(unittest.TestCase): class TestHelper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
...@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase): ...@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
for tensor_source in TensorSource: for tensor_source in TensorSource:
target_scaling_mode = ( target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING ScalingMode.NVFP4_2D_SCALING
if tensor_source == TensorSource.KERNEL if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING else ScalingMode.NVFP4_1D_SCALING
) )
self.assertEqual( self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
) )
self.assertEqual(
get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
)
self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
self.assertEqual(
get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization
)
def _compare_nvfp4_scaling_quantizers(self, test):
"""Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""
def assertion_func(quantizer, tensor_source):
if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
self.assertIsNone(quantizer.stochastic_rounding_rng_state)
else:
self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)
expected_rht = (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
and not test.disable_rht
)
self.assertEqual(quantizer.use_rht, expected_rht)
x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assertion_func)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)
jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self): def test_autocast_delayed_scaling(self):
...@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase): ...@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs) self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
bs = NVFP4BlockScaling(
disable_stochastic_rounding=True,
disable_rht=True,
disable_2d_quantization=True,
)
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
self._check_default_state() self._check_default_state()
...@@ -44,7 +44,6 @@ from ..quantize import ( ...@@ -44,7 +44,6 @@ from ..quantize import (
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
should_use_rht,
) )
from .misc import get_padded_spec, is_all_reduce_in_float32 from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import ( from ..sharding import (
...@@ -169,16 +168,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ...@@ -169,16 +168,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x)
def uses_rht(q: AbstractBaseTensor) -> bool: def has_rht_applied(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and should_use_rht( return isinstance(q, ScaledTensor1x) and q.has_rht_applied
q.scaling_mode, is_colwise=q.is_colwise
)
# TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
assert uses_rht(lhs_q) == uses_rht(rhs_q), ( "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
"With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise" " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
" quantized as well. This is to ensure the RHT is applied to both and will cancel out in" " GEMM."
" the GEMM."
) )
return lhs_q, rhs_q return lhs_q, rhs_q
......
...@@ -31,7 +31,7 @@ from .misc import ( ...@@ -31,7 +31,7 @@ from .misc import (
from ..sharding import ( from ..sharding import (
all_reduce_max_along_all_axes_except_PP, all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp, all_reduce_sum_along_dp_fsdp,
num_of_devices, get_num_devices_in_mesh,
) )
from ..quantize import ( from ..quantize import (
ScaledTensor2x, ScaledTensor2x,
...@@ -45,7 +45,6 @@ from ..quantize import ( ...@@ -45,7 +45,6 @@ from ..quantize import (
compute_scale_from_amax, compute_scale_from_amax,
NoScaleTensor, NoScaleTensor,
get_rht_matrix, get_rht_matrix,
should_use_rht,
) )
...@@ -108,17 +107,18 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -108,17 +107,18 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
"sr_rng_state must be a uint32 array when stochastic_rounding is True but" "sr_rng_state must be a uint32 array when stochastic_rounding is True but"
f" received {sr_rng_state_aval}" f" received {sr_rng_state_aval}"
) )
if is_outer: if is_outer and get_num_devices_in_mesh() > 1:
assert ( assert (
sr_rng_state_aval.shape[0] == num_of_devices() sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
and sr_rng_state_aval.shape[1] == 4 and sr_rng_state_aval.shape[1] == 4
), ( ), (
"sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is" "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
f" True and is_outer is True but received {sr_rng_state_aval.shape}" f" True and is_outer is True but received {sr_rng_state_aval.shape}"
) )
else: else:
assert sr_rng_state_aval.shape == (4,), ( # We cannot assert the shape is exactly (4,) here because if the quantized data is not perfectly sharded across all devices then we will have extra rng state here. For example, this could occur when the weights are not sharded when using data parallelism. However, this is okay because the extra rng state will simply not be used and each device still has a unique rng state.
"Sharded sr_rng_state must be of shape (4,) per device when" assert sr_rng_state_aval.size >= 4, (
"Sharded sr_rng_state must have at least 4 elements per device when"
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
) )
...@@ -552,8 +552,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -552,8 +552,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
) )
# TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(tuple(x for x in x_spec if x is not None), None),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -564,6 +569,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -564,6 +569,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) )
def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix): def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
if sr_rng_state.size > 4:
# See comment in abstract method for explanation of why we cannot assert exact shape
sr_rng_state = sr_rng_state.flatten()[:4]
( (
local_x, local_x,
local_colwise_x, local_colwise_x,
...@@ -754,9 +762,10 @@ def _quantize_dbias_impl( ...@@ -754,9 +762,10 @@ def _quantize_dbias_impl(
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation # fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
is_unsupported = ( is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not (
quantizer.q_layout == QuantizeLayout.COLWISE quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING and hasattr(quantizer, "use_rht")
and quantizer.use_rht
) )
if is_unsupported or not PrimitiveClass.enabled(): if is_unsupported or not PrimitiveClass.enabled():
if is_dbias: if is_dbias:
...@@ -792,7 +801,7 @@ def _quantize_dbias_impl( ...@@ -792,7 +801,7 @@ def _quantize_dbias_impl(
rht_matrix = jnp.empty((1, 1), jnp.bfloat16) rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
amax = x.amax amax = x.amax
if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout): if hasattr(quantizer, "use_rht") and quantizer.use_rht:
use_rht = True use_rht = True
rht_matrix = get_rht_matrix() rht_matrix = get_rht_matrix()
...@@ -861,7 +870,11 @@ def _quantize_dbias_impl( ...@@ -861,7 +870,11 @@ def _quantize_dbias_impl(
x.data, x.data,
scale, scale,
amax, amax,
sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32), (
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix, rht_matrix,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
...@@ -902,6 +915,7 @@ def _quantize_dbias_impl( ...@@ -902,6 +915,7 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout, q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
colwise_has_rht_applied=use_rht,
) )
return out, dbias.astype(dq_dtype) return out, dbias.astype(dq_dtype)
......
...@@ -15,7 +15,7 @@ import jax ...@@ -15,7 +15,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht from .hadamard import apply_rht
__all__ = ["ScalingModeToDequantizerMap"] __all__ = ["ScalingModeToDequantizerMap"]
...@@ -171,7 +171,9 @@ class NVFP4Dequantizer(Dequantizer): ...@@ -171,7 +171,9 @@ class NVFP4Dequantizer(Dequantizer):
""" """
@staticmethod @staticmethod
def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis): def _dequantize_func(
data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis, has_rht_applied
):
"""Dequantize a tensor using block scaling. """Dequantize a tensor using block scaling.
Args: Args:
...@@ -182,6 +184,7 @@ class NVFP4Dequantizer(Dequantizer): ...@@ -182,6 +184,7 @@ class NVFP4Dequantizer(Dequantizer):
scaling_mode: The scaling mode used for quantization scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D flatten_axis: The axis along which the tensor could be flattened to 2D
has_rht_applied: Whether the quantization has RHT applied and we need to apply the inverse RHT to dequantize
Returns: Returns:
The dequantized tensor The dequantized tensor
...@@ -223,8 +226,7 @@ class NVFP4Dequantizer(Dequantizer): ...@@ -223,8 +226,7 @@ class NVFP4Dequantizer(Dequantizer):
out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape) out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)
# Apply inverse of RHT if needed # Apply inverse of RHT if needed
use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise) if has_rht_applied:
if use_rht:
out = apply_rht(out, inverse=True) out = apply_rht(out, inverse=True)
return out return out
...@@ -247,6 +249,7 @@ class NVFP4Dequantizer(Dequantizer): ...@@ -247,6 +249,7 @@ class NVFP4Dequantizer(Dequantizer):
scaled_tensor.scaling_mode, scaled_tensor.scaling_mode,
scaled_tensor.is_colwise, scaled_tensor.is_colwise,
scaled_tensor.flatten_axis, scaled_tensor.flatten_axis,
scaled_tensor.has_rht_applied,
) )
......
...@@ -4,32 +4,6 @@ ...@@ -4,32 +4,6 @@
"""Randomized Hadamard Transform (RHT) utilities for JAX.""" """Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp import jax.numpy as jnp
from .scaling_modes import ScalingMode
def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool:
"""Determine if RHT (Randomized Hadamard Transform) should be used.
Args:
scaling_mode: The scaling mode of the tensor.
is_colwise: Whether the tensor is column-wise. Only one of is_colwise or q_layout should be provided.
q_layout: The quantization layout of the tensor. Only one of is_colwise or q_layout should be provided.
Returns:
bool: True if RHT should be used, False otherwise.
"""
# Delayed import to avoid circular dependencies
from .quantizer import QuantizeLayout
assert (is_colwise is None) != (
q_layout is None
), "Exactly one of is_colwise or q_layout must be provided."
if q_layout is not None:
is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE}
return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
def get_wgrad_sign_vector() -> list[int]: def get_wgrad_sign_vector() -> list[int]:
"""Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization.""" """Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
......
...@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod ...@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
import hashlib
from typing import Optional, Tuple, Dict, Union, Sequence, Type, List from typing import Optional, Tuple, Dict, Union, Sequence, Type, List
from functools import reduce, lru_cache from functools import reduce, lru_cache
import operator import operator
...@@ -35,7 +36,7 @@ from transformer_engine.common.recipe import ( ...@@ -35,7 +36,7 @@ from transformer_engine.common.recipe import (
from transformer_engine.jax.sharding import ( from transformer_engine.jax.sharding import (
global_shard_guard, global_shard_guard,
MeshResource, MeshResource,
num_of_devices, get_num_devices_in_mesh,
get_all_mesh_axes, get_all_mesh_axes,
with_sharding_constraint, with_sharding_constraint,
) )
...@@ -561,29 +562,87 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -561,29 +562,87 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig):
return QuantizeMeta() return QuantizeMeta()
@dataclass
class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for NVFP4 scaling recipe. """Configuration class for NVFP4 scaling recipe.
This class provides specific initialization and finalization for NVFP4 scaling quantization mode. This class provides specific initialization and finalization for NVFP4 scaling quantization mode.
""" """
DISABLE_STOCHASTIC_ROUNDING: bool = False
DISABLE_RHT: bool = False
DISABLE_2D_QUANTIZATION: bool = False
def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize block scaling FP8 configuration. """Initialize block scaling NVFP4 configuration.
Args: Args:
fp8_recipe: The FP8 recipe to use for initialization fp8_recipe: The quantization recipe to use for initialization
""" """
assert isinstance(fp8_recipe, NVFP4BlockScaling)
self.INITIALIZED = True self.INITIALIZED = True
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format) self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format)
self.AMAX_HISTORY_LEN = 0 self.AMAX_HISTORY_LEN = 0
self.DISABLE_STOCHASTIC_ROUNDING = fp8_recipe.disable_stochastic_rounding
self.DISABLE_RHT = fp8_recipe.disable_rht
self.DISABLE_2D_QUANTIZATION = fp8_recipe.disable_2d_quantization
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type.""" """Gets the scaling mode for a specific tensor's usage type."""
if tensor_source == TensorSource.KERNEL: if (not self.DISABLE_2D_QUANTIZATION) and tensor_source == TensorSource.KERNEL:
return ScalingMode.NVFP4_2D_SCALING return ScalingMode.NVFP4_2D_SCALING
# for x and grad # for x and grad
return ScalingMode.NVFP4_1D_SCALING return ScalingMode.NVFP4_1D_SCALING
def _make_rht_quantize_meta(self, q_layout, tensor_source: TensorSource) -> QuantizeMeta:
"""Create the quantization metadata for RHT if applicable."""
# Imported here to prevent circular import
from transformer_engine.jax.quantize import QuantizeLayout
use_rht = self.get_scaling_mode(
tensor_source
) == ScalingMode.NVFP4_1D_SCALING and q_layout in {
QuantizeLayout.ROWWISE_COLWISE,
QuantizeLayout.COLWISE,
}
if self.DISABLE_RHT:
use_rht = False
return QuantizeMeta(use_rht=use_rht)
def _make_stochastic_rounding_rng_state(
self, module, tensor_source: TensorSource, quantizer_name: str
) -> jnp.ndarray:
"""Create the stochastic rounding rng state if applicable."""
if self.DISABLE_STOCHASTIC_ROUNDING:
return QuantizeMeta()
if tensor_source != TensorSource.DGRAD:
# Only DGRAD uses stochastic rounding
return QuantizeMeta()
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
# Use hashlib to get a deterministic hash value for quantizer_name
quantizer_hash = (
int(hashlib.sha256(quantizer_name.encode("utf-8")).hexdigest(), 16)
% jnp.iinfo(jnp.int32).max
)
sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash)
# Generate 4 random uint32 values from the JAX PRNG key
shape = (4,)
if get_num_devices_in_mesh() > 1:
shape = (get_num_devices_in_mesh(), 4)
sr_jax_rng_state = jax.random.randint(
sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
)
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
def get_quantize_flax_meta( def get_quantize_flax_meta(
self, self,
module, module,
...@@ -603,27 +662,14 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -603,27 +662,14 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
Returns: Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
""" """
if tensor_source != TensorSource.DGRAD: # Imported here to prevent circular import
# Only DGRAD uses stochastic rounding from transformer_engine.jax.quantize import QuantizeLayout
return QuantizeMeta()
# TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it.
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
sr_jax_rng = jax.jit(jax.random.fold_in)(
sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max
)
# Generate 4 random uint32 values from the JAX PRNG key return QuantizeMeta.merge(
sr_jax_rng_state = jax.random.randint( self._make_rht_quantize_meta(QuantizeLayout.ROWWISE_COLWISE, tensor_source),
sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32 self._make_stochastic_rounding_rng_state(module, tensor_source, quantizer_name),
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
) )
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
_QUANTIZE_CONFIG = NoOpQuantizeConfig() _QUANTIZE_CONFIG = NoOpQuantizeConfig()
......
...@@ -26,6 +26,26 @@ class QuantizeMeta: ...@@ -26,6 +26,26 @@ class QuantizeMeta:
""" """
@staticmethod
def merge(a: "QuantizeMeta", b: "QuantizeMeta") -> "QuantizeMeta":
"""Merge two QuantizeMeta instances.
Args:
a (QuantizeMeta): The first QuantizeMeta instance.
b (QuantizeMeta): The second QuantizeMeta instance.
Returns:
QuantizeMeta: A new QuantizeMeta instance with merged metadata.
"""
assert isinstance(a, QuantizeMeta)
assert isinstance(b, QuantizeMeta)
for key in b.get_kwargs_dictionary().keys():
if key in a.get_kwargs_dictionary():
assert (
a.get_kwargs_dictionary()[key] == b.get_kwargs_dictionary()[key]
), f"Conflict in merging QuantizeMeta: {key} has different values."
return QuantizeMeta(**{**a.get_kwargs_dictionary(), **b.get_kwargs_dictionary()})
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._kwargs = kwargs self._kwargs = kwargs
......
...@@ -19,7 +19,7 @@ from transformer_engine_jax import QuantizeLayout ...@@ -19,7 +19,7 @@ from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe from transformer_engine.common import recipe
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht from .hadamard import apply_rht
from .tensor import ( from .tensor import (
ScaledTensor, ScaledTensor,
ScaledTensor1x, ScaledTensor1x,
...@@ -590,11 +590,13 @@ class NVFP4Quantizer(Quantizer): ...@@ -590,11 +590,13 @@ class NVFP4Quantizer(Quantizer):
q_layout: Quantization axis q_layout: Quantization axis
data_layout: Data layout string (default: "NT") data_layout: Data layout string (default: "NT")
stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled. stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled.
use_rht: Whether to apply Randomized Hadamard Transform (RHT) before quantization.
""" """
scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
data_layout: str = "NT" data_layout: str = "NT"
use_rht: bool = False
stochastic_rounding_rng_state: Optional[jnp.ndarray] = None stochastic_rounding_rng_state: Optional[jnp.ndarray] = None
def __post_init__(self): def __post_init__(self):
...@@ -603,6 +605,30 @@ class NVFP4Quantizer(Quantizer): ...@@ -603,6 +605,30 @@ class NVFP4Quantizer(Quantizer):
), "NVFP4 quantization must use a q_dtype of float4_e2m1fn" ), "NVFP4 quantization must use a q_dtype of float4_e2m1fn"
assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes" assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes"
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.stochastic_rounding_rng_state,)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.use_rht)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstruct a quantizer from its flattened representation.
Args:
aux_data: Auxiliary data containing quantizer parameters
children: Unused children data
Returns:
A reconstructed Quantizer instance
"""
stochastic_rounding_rng_state = children[0]
return cls(*aux_data, stochastic_rounding_rng_state=stochastic_rounding_rng_state)
def _apply_stochastic_rounding(self, x): def _apply_stochastic_rounding(self, x):
assert ( assert (
self.stochastic_rounding_rng_state is not None self.stochastic_rounding_rng_state is not None
...@@ -688,8 +714,9 @@ class NVFP4Quantizer(Quantizer): ...@@ -688,8 +714,9 @@ class NVFP4Quantizer(Quantizer):
flatten_axis = x.ndim - flatten_axis flatten_axis = x.ndim - flatten_axis
x_shape = x.shape x_shape = x.shape
if should_use_rht(self.scaling_mode, is_colwise=is_colwise): # We currently only have a single flag 'use_rht' on the quantizer. To avoid an unused rowwise flag, we assume RHT is only used for colwise quantization for now.
# We only apply RHT for 1D colwise nvfp4 use_rht = self.use_rht and is_colwise and self.scaling_mode == ScalingMode.NVFP4_1D_SCALING
if use_rht:
x = apply_rht(x) x = apply_rht(x)
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
...@@ -790,6 +817,7 @@ class NVFP4Quantizer(Quantizer): ...@@ -790,6 +817,7 @@ class NVFP4Quantizer(Quantizer):
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=rowwise_flatten_axis, flatten_axis=rowwise_flatten_axis,
has_rht_applied=use_rht,
) )
......
...@@ -175,6 +175,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -175,6 +175,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise: Whether the tensor uses column-wise quantization is_colwise: Whether the tensor uses column-wise quantization
data_layout: The data_layout specification for the tensor data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor flatten_axis: The quantization axis for the tensor
has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization
""" """
scale_inv: jnp.ndarray scale_inv: jnp.ndarray
...@@ -184,6 +185,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -184,6 +185,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise: bool is_colwise: bool
data_layout: str data_layout: str
flatten_axis: int flatten_axis: int
has_rht_applied: bool
def __post_init__(self): def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization. """Validates and adjusts the scale_inv shape after initialization.
...@@ -243,6 +245,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -243,6 +245,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
self.is_colwise, self.is_colwise,
self.data_layout, self.data_layout,
self.flatten_axis, self.flatten_axis,
self.has_rht_applied,
) )
return (children, aux_data) return (children, aux_data)
...@@ -314,6 +317,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -314,6 +317,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise=self.is_colwise, is_colwise=self.is_colwise,
data_layout=self.data_layout, data_layout=self.data_layout,
flatten_axis=self.flatten_axis, flatten_axis=self.flatten_axis,
has_rht_applied=self.has_rht_applied,
) )
...@@ -354,6 +358,7 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -354,6 +358,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self.group_sizes = group_sizes self.group_sizes = group_sizes
self.original_shape = original_shape self.original_shape = original_shape
self.group_axis = group_axis self.group_axis = group_axis
# TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4
super().__init__( super().__init__(
data=data, data=data,
scale_inv=scale_inv, scale_inv=scale_inv,
...@@ -364,6 +369,7 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -364,6 +369,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
is_colwise=is_colwise, is_colwise=is_colwise,
data_layout=data_layout, data_layout=data_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
has_rht_applied=False,
) )
def __post_init__(self): def __post_init__(self):
...@@ -515,6 +521,7 @@ class ScaledTensorFactory: ...@@ -515,6 +521,7 @@ class ScaledTensorFactory:
group_sizes=None, group_sizes=None,
original_shape=None, original_shape=None,
group_axis=0, group_axis=0,
has_rht_applied=False,
): ):
"""Creates a single-scale quantized tensor. """Creates a single-scale quantized tensor.
...@@ -530,6 +537,7 @@ class ScaledTensorFactory: ...@@ -530,6 +537,7 @@ class ScaledTensorFactory:
group_sizes: Array of ints containing the size of each group (default: None) group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None) original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0) group_axis: The axis along which grouping is performed (default: 0)
has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)
Returns: Returns:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
...@@ -593,6 +601,7 @@ class ScaledTensorFactory: ...@@ -593,6 +601,7 @@ class ScaledTensorFactory:
is_colwise=is_colwise, is_colwise=is_colwise,
data_layout=data_layout, data_layout=data_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
has_rht_applied=has_rht_applied,
) )
@staticmethod @staticmethod
...@@ -610,6 +619,8 @@ class ScaledTensorFactory: ...@@ -610,6 +619,8 @@ class ScaledTensorFactory:
group_sizes=None, group_sizes=None,
original_shape=None, original_shape=None,
group_axis=0, group_axis=0,
rowwise_has_rht_applied=False,
colwise_has_rht_applied=False,
): ):
"""Creates a double-scale quantized tensor. """Creates a double-scale quantized tensor.
...@@ -626,6 +637,8 @@ class ScaledTensorFactory: ...@@ -626,6 +637,8 @@ class ScaledTensorFactory:
group_sizes: Array containing the size of each group (default: None) group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None) original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0) group_axis: The axis along which grouping is performed (default: 0)
rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
Returns: Returns:
A ScaledTensor2x instance A ScaledTensor2x instance
...@@ -648,6 +661,7 @@ class ScaledTensorFactory: ...@@ -648,6 +661,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
has_rht_applied=rowwise_has_rht_applied,
) )
colwise_tensor = ScaledTensorFactory.create_1x( colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data, colwise_data,
...@@ -661,6 +675,7 @@ class ScaledTensorFactory: ...@@ -661,6 +675,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
has_rht_applied=colwise_has_rht_applied,
) )
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
...@@ -680,6 +695,8 @@ class ScaledTensorFactory: ...@@ -680,6 +695,8 @@ class ScaledTensorFactory:
group_sizes: jnp.ndarray = None, group_sizes: jnp.ndarray = None,
original_shape: Tuple[int] = None, original_shape: Tuple[int] = None,
group_axis: int = 0, group_axis: int = 0,
rowwise_has_rht_applied: bool = False,
colwise_has_rht_applied: bool = False,
): ):
"""Creates a scaled tensor based on the quantization axis. """Creates a scaled tensor based on the quantization axis.
...@@ -696,10 +713,14 @@ class ScaledTensorFactory: ...@@ -696,10 +713,14 @@ class ScaledTensorFactory:
group_sizes: Array containing the size of each group (default: None) group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None) original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0) group_axis: The axis along which grouping is performed (default: 0)
rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
Returns: Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
""" """
assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"
if q_layout == QuantizeLayout.ROWWISE_COLWISE: if q_layout == QuantizeLayout.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x( return ScaledTensorFactory.create_2x(
data, data,
...@@ -715,6 +736,8 @@ class ScaledTensorFactory: ...@@ -715,6 +736,8 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
rowwise_has_rht_applied=rowwise_has_rht_applied,
colwise_has_rht_applied=colwise_has_rht_applied,
) )
is_colwise = q_layout == QuantizeLayout.COLWISE is_colwise = q_layout == QuantizeLayout.COLWISE
...@@ -731,6 +754,7 @@ class ScaledTensorFactory: ...@@ -731,6 +754,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
has_rht_applied=colwise_has_rht_applied,
) )
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
...@@ -745,6 +769,7 @@ class ScaledTensorFactory: ...@@ -745,6 +769,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
has_rht_applied=rowwise_has_rht_applied,
) )
......
...@@ -238,6 +238,19 @@ def num_of_devices(): ...@@ -238,6 +238,19 @@ def num_of_devices():
return len(jax.devices()) return len(jax.devices())
def get_num_devices_in_mesh(mesh=None):
"""
Get the number of devices in the given mesh.
If the mesh is None, it would be replaced
by the global mesh.
"""
if mesh is None:
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return 1
return np.prod(list(mesh.shape.values()))
def get_mesh_axis_size(axis, mesh=None): def get_mesh_axis_size(axis, mesh=None):
""" """
Get the axis size of the given mesh. Get the axis size of the given mesh.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment