Unverified commit 818b30cc authored by jberchtold-nvidia, committed by GitHub

[JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quantization (#2270)



* [JAX] Support recipe flags for disabling SR, RHT, and 2D quantization
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix issue with SR state being erased due to pytree handling of NVFP4Quantizer
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add test for SR state preservation across VJP boundaries
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix sharding of SR rng state
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* update tolerances slightly now that SR is enabled
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use hashlib for deterministic hashes across runs for SR
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* rename uses_rht on scaled tensors to has_rht_applied
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* add assert
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Move decision of whether to use RHT into helper.py and add dedicated RHT tests
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix use_rht attr usage
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix pure-jax rht usage criteria
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Adjust tolerances after rebase
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent ce2e8bd1
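
For reference, a minimal sketch of turning the new recipe flags on (the flag names are taken from this diff; the autocast import path and mesh setup are assumptions):

from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.jax import autocast  # assumed import path
from transformer_engine.jax.sharding import MeshResource

recipe = NVFP4BlockScaling(
    disable_stochastic_rounding=True,  # skip SR on the dgrad quantizer
    disable_rht=True,                  # skip the Randomized Hadamard Transform
    disable_2d_quantization=True,      # keep kernel scaling on NVFP4_1D_SCALING
)
with autocast(enabled=True, recipe=recipe, mesh_resource=MeshResource()):
    ...  # run the model; quantizers are created from this recipe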
......@@ -672,7 +672,7 @@ class TestEncoder(unittest.TestCase):
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling")
assert result[0] < 0.451 and result[1] > 0.79
assert result[0] < 0.451 and result[1] > 0.788
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
......@@ -710,7 +710,7 @@ class TestEncoder(unittest.TestCase):
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
assert result[0] < 0.451 and result[1] > 0.79
assert result[0] < 0.451 and result[1] > 0.788
if __name__ == "__main__":
......
......@@ -390,7 +390,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.476 and actual[1] > 0.775
assert actual[0] < 0.477 and actual[1] > 0.769
if __name__ == "__main__":
......
......@@ -40,7 +40,6 @@ from transformer_engine.jax.quantize import (
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
should_use_rht,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
......@@ -685,21 +684,14 @@ class TestQuantize:
Purely quantization related tests that will always test on a wider set of types and shapes
"""
def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Temporary hack to skip unsupported FP4 cases until we implement them"""
def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
"""Skip unsupported dtypes for given scaling mode. For example, NVFP4 only supports the float4_e2m1 dtype not float8 dtypes."""
if q_dtype not in scaling_mode.get_compatible_q_dtypes():
pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
return
# HACK: FIXME TODO(jberchtold)
row = reduce(operator.mul, input_shape[flatten_axis:], 1)
col = reduce(operator.mul, input_shape[:flatten_axis], 1)
will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
if will_use_rht and (row % 64 != 0 or col % 128 != 0):
pytest.skip("Unfused RHT is not supported currently, skipping")
def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
self._skip_unsupported_dtypes(q_dtype, scaling_mode)
key = jax.random.PRNGKey(0)
......@@ -780,22 +772,8 @@ class TestQuantize:
assert_dequantized_scaled_tensor(scaled_tensor, x)
def _should_use_precise_comparison(
self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
):
# TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical results between TE and JAX RHT+quant implementations. Currently for certain shapes the quantized fp4 data differs by a small amount on <0.5% of the values.
RHT_SLIGHT_MISMATCH_SHAPES = [
((32, 256, 128), -1),
((64, 32, 32, 256), -1),
((8192, 2, 4096), -2),
]
if (
should_use_rht(scaling_mode, q_layout=q_layout)
and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES
self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
):
# TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes
return False
if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
# With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation
return False
......@@ -805,7 +783,7 @@ class TestQuantize:
def test_quantize_bitwise(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
self._skip_unsupported_dtypes(q_dtype, scaling_mode)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
......@@ -816,28 +794,20 @@ class TestQuantize:
jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
try:
te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
except AssertionError as e:
if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
error_message = e.args[0]
if "RHT requires input to be bfloat16" in error_message:
# Successfully caught the expected error, early return from the test
return
raise e
assert_bitwise_scaled_tensors(
te_output,
jax_output,
precise_comparison=self._should_use_precise_comparison(
in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
),
)
def test_quantize_bitwise_jitted(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
self._skip_unsupported_dtypes(q_dtype, scaling_mode)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
......@@ -851,21 +821,13 @@ class TestQuantize:
jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
try:
te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
except AssertionError as e:
if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
error_message = e.args[0]
if "RHT requires input to be bfloat16" in error_message:
# Successfully caught the expected error, early return from the test
return
raise e
assert_bitwise_scaled_tensors(
te_output,
jax_output,
precise_comparison=self._should_use_precise_comparison(
in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
),
)
......@@ -985,12 +947,6 @@ class TestStochasticRounding:
def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
# HACK: FIXME TODO(jberchtold)
row = reduce(operator.mul, input_shape[flatten_axis:], 1)
col = reduce(operator.mul, input_shape[:flatten_axis], 1)
will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
if will_use_rht and (row % 64 != 0 or col % 128 != 0):
pytest.skip("Unfused RHT is not supported currently, skipping")
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
......@@ -1007,6 +963,97 @@ class TestStochasticRounding:
assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)
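
For intuition, a minimal self-contained sketch of unbiased stochastic rounding in plain JAX (illustrative only; the TE kernels implement this in fused form, driven by the rng state tested here):

import jax
import jax.numpy as jnp

def stochastic_round(key, x, step):
    # Round x down to a multiple of `step`, then round up with probability equal
    # to the fractional remainder, so the rounding is unbiased in expectation.
    lo = jnp.floor(x / step) * step
    p_up = (x - lo) / step
    go_up = jax.random.uniform(key, x.shape) < p_up
    return lo + step * go_up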
@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper(
"scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
)
class TestRandomizedHadamardTransform:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
@pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
def test_rht_quantize_bitwise_jitted(
self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
):
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
use_rht=True,
)
jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))
jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)
te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)
assert_bitwise_scaled_tensors(te_output, jax_output)
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T":
a = jnp.swapaxes(a, -1, -2)
if data_layout[1] == "T":
b = jnp.swapaxes(b, -1, -2)
return jnp.dot(a, b)
def _generate_gemm_input(self, m, n, k, data_layout):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(
subkeys[0],
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=jnp.bfloat16,
) / jnp.sqrt(k)
w = jax.random.uniform(
subkeys[1],
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=jnp.bfloat16,
) / jnp.sqrt(n)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
return (x, w, contracting_dims)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
# We do not test the NN and TT layouts here because RHT currently supports only the colwise layout, so those layouts would not have RHT applied to both GEMM inputs
@pytest_parametrize_wrapper("data_layout", ["TN", "NT"])
@pytest_parametrize_wrapper("with_jax_gemm", [True, False])
def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm):
key = jax.random.PRNGKey(0)
lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
lhs_quantizer = QuantizerFactory.create(
scaling_mode=lhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
use_rht=True,
)
rhs_quantizer = QuantizerFactory.create(
scaling_mode=rhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
use_rht=True,
)
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)
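
The RHT-in-GEMM cancellation this test relies on can be checked with a tiny orthonormal Hadamard matrix (an illustrative sketch, not the TE kernel's actual transform):

import jax.numpy as jnp

H = jnp.array([[1.0, 1.0], [1.0, -1.0]]) / jnp.sqrt(2.0)  # orthonormal: H @ H.T == I
x = jnp.arange(6.0).reshape(3, 2)
w = jnp.arange(10.0).reshape(2, 5)
# Applying H along the contracting dimension of both operands leaves the product
# unchanged, so the transform cancels out in the GEMM:
assert jnp.allclose((x @ H) @ (H.T @ w), x @ w)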
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
......
......@@ -3,11 +3,13 @@
# See LICENSE for license information.
import unittest
from functools import partial
import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from utils import assert_allclose
from transformer_engine.common.recipe import (
......@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
ScalingMode,
update_collections,
TensorSource,
QuantizerFactory,
QuantizeLayout,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax.flax.module import TransformerEngineBase
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
"""Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
# Define a function with a custom VJP (vector-Jacobian product)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def quantizer_check(inner_quantizer_set, assertion_func, x):
return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)
def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
assertion_func(inner_quantizer_set.x, TensorSource.X)
assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
return x
def quantizer_check_bwd(ctx, g):
return (g,)
quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
return quantizer_check(outer_quantizer_set, assertion_func, x)
class TestModule(TransformerEngineBase):
"""A simple module to test quantizer creation and reconstruction across VJP boundaries."""
# Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
assertion_func: callable
@nn.compact
def __call__(self, x):
quantizer_set = self.generate_quantizer_set()
return quantizer_check_vjp(quantizer_set, self.assertion_func, x)
class TestHelper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
......@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
for tensor_source in TensorSource:
target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING
if tensor_source == TensorSource.KERNEL
if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING
)
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
)
self.assertEqual(
get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
)
self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
self.assertEqual(
get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization
)
def _compare_nvfp4_scaling_quantizers(self, test):
"""Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""
def assertion_func(quantizer, tensor_source):
if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
self.assertIsNone(quantizer.stochastic_rounding_rng_state)
else:
self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)
expected_rht = (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
and not test.disable_rht
)
self.assertEqual(quantizer.use_rht, expected_rht)
x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assertion_func)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)
jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self):
......@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
bs = NVFP4BlockScaling(
disable_stochastic_rounding=True,
disable_rht=True,
disable_2d_quantization=True,
)
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
self._check_default_state()
......@@ -44,7 +44,6 @@ from ..quantize import (
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
should_use_rht,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......@@ -169,16 +168,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x)
def uses_rht(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and should_use_rht(
q.scaling_mode, is_colwise=q.is_colwise
)
def has_rht_applied(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and q.has_rht_applied
# TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class
assert uses_rht(lhs_q) == uses_rht(rhs_q), (
"With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise"
" quantized as well. This is to ensure the RHT is applied to both and will cancel out in"
" the GEMM."
assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
"With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
" with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
" GEMM."
)
return lhs_q, rhs_q
......
......@@ -31,7 +31,7 @@ from .misc import (
from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp,
num_of_devices,
get_num_devices_in_mesh,
)
from ..quantize import (
ScaledTensor2x,
......@@ -45,7 +45,6 @@ from ..quantize import (
compute_scale_from_amax,
NoScaleTensor,
get_rht_matrix,
should_use_rht,
)
......@@ -108,17 +107,18 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
"sr_rng_state must be a uint32 array when stochastic_rounding is True but"
f" received {sr_rng_state_aval}"
)
if is_outer:
if is_outer and get_num_devices_in_mesh() > 1:
assert (
sr_rng_state_aval.shape[0] == num_of_devices()
sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
and sr_rng_state_aval.shape[1] == 4
), (
"sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
f" True and is_outer is True but received {sr_rng_state_aval.shape}"
)
else:
assert sr_rng_state_aval.shape == (4,), (
"Sharded sr_rng_state must be of shape (4,) per device when"
# We cannot assert that the shape is exactly (4,) here: if the quantized data is
# not perfectly sharded across all devices, extra rng state remains, e.g. when
# the weights are unsharded under data parallelism. This is okay because the
# extra rng state is simply unused and each device still has a unique rng state.
assert sr_rng_state_aval.size >= 4, (
"Sharded sr_rng_state must have at least 4 elements per device when"
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
)
......@@ -552,8 +552,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
# TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(tuple(x for x in x_spec if x is not None), None),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings)
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -564,6 +569,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
if sr_rng_state.size > 4:
# See comment in abstract method for explanation of why we cannot assert exact shape
sr_rng_state = sr_rng_state.flatten()[:4]
(
local_x,
local_colwise_x,
......@@ -754,9 +762,10 @@ def _quantize_dbias_impl(
# If the TE/common custom quantize op is disabled, or if the quantizer layout is
# COLWISE without RHT, fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
is_unsupported = (
quantizer.q_layout == QuantizeLayout.COLWISE
and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING
is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and hasattr(quantizer, "use_rht")
and quantizer.use_rht
)
if is_unsupported or not PrimitiveClass.enabled():
if is_dbias:
......@@ -792,7 +801,7 @@ def _quantize_dbias_impl(
rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
amax = x.amax
if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout):
if hasattr(quantizer, "use_rht") and quantizer.use_rht:
use_rht = True
rht_matrix = get_rht_matrix()
......@@ -861,7 +870,11 @@ def _quantize_dbias_impl(
x.data,
scale,
amax,
sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32),
(
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix,
out_dtype=quantizer.q_dtype,
......@@ -902,6 +915,7 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
colwise_has_rht_applied=use_rht,
)
return out, dbias.astype(dq_dtype)
......
......@@ -15,7 +15,7 @@ import jax
import jax.numpy as jnp
from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
from .hadamard import apply_rht
__all__ = ["ScalingModeToDequantizerMap"]
......@@ -171,7 +171,9 @@ class NVFP4Dequantizer(Dequantizer):
"""
@staticmethod
def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis):
def _dequantize_func(
data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis, has_rht_applied
):
"""Dequantize a tensor using block scaling.
Args:
......@@ -182,6 +184,7 @@ class NVFP4Dequantizer(Dequantizer):
scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D
has_rht_applied: Whether RHT was applied during quantization, in which case the inverse RHT must be applied to dequantize
Returns:
The dequantized tensor
......@@ -223,8 +226,7 @@ class NVFP4Dequantizer(Dequantizer):
out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)
# Apply inverse of RHT if needed
use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise)
if use_rht:
if has_rht_applied:
out = apply_rht(out, inverse=True)
return out
......@@ -247,6 +249,7 @@ class NVFP4Dequantizer(Dequantizer):
scaled_tensor.scaling_mode,
scaled_tensor.is_colwise,
scaled_tensor.flatten_axis,
scaled_tensor.has_rht_applied,
)
......
......@@ -4,32 +4,6 @@
"""Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp
from .scaling_modes import ScalingMode
def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool:
"""Determine if RHT (Randomized Hadamard Transform) should be used.
Args:
scaling_mode: The scaling mode of the tensor.
is_colwise: Whether the tensor is column-wise. Only one of is_colwise or q_layout should be provided.
q_layout: The quantization layout of the tensor. Only one of is_colwise or q_layout should be provided.
Returns:
bool: True if RHT should be used, False otherwise.
"""
# Delayed import to avoid circular dependencies
from .quantizer import QuantizeLayout
assert (is_colwise is None) != (
q_layout is None
), "Exactly one of is_colwise or q_layout must be provided."
if q_layout is not None:
is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE}
return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
def get_wgrad_sign_vector() -> list[int]:
"""Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
......
......@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
import hashlib
from typing import Optional, Tuple, Dict, Union, Sequence, Type, List
from functools import reduce, lru_cache
import operator
......@@ -35,7 +36,7 @@ from transformer_engine.common.recipe import (
from transformer_engine.jax.sharding import (
global_shard_guard,
MeshResource,
num_of_devices,
get_num_devices_in_mesh,
get_all_mesh_axes,
with_sharding_constraint,
)
......@@ -561,29 +562,87 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig):
return QuantizeMeta()
@dataclass
class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for NVFP4 scaling recipe.
This class provides specific initialization and finalization for NVFP4 scaling quantization mode.
"""
DISABLE_STOCHASTIC_ROUNDING: bool = False
DISABLE_RHT: bool = False
DISABLE_2D_QUANTIZATION: bool = False
def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize block scaling FP8 configuration.
"""Initialize block scaling NVFP4 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
fp8_recipe: The quantization recipe to use for initialization
"""
assert isinstance(fp8_recipe, NVFP4BlockScaling)
self.INITIALIZED = True
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format)
self.AMAX_HISTORY_LEN = 0
self.DISABLE_STOCHASTIC_ROUNDING = fp8_recipe.disable_stochastic_rounding
self.DISABLE_RHT = fp8_recipe.disable_rht
self.DISABLE_2D_QUANTIZATION = fp8_recipe.disable_2d_quantization
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
if tensor_source == TensorSource.KERNEL:
if (not self.DISABLE_2D_QUANTIZATION) and tensor_source == TensorSource.KERNEL:
return ScalingMode.NVFP4_2D_SCALING
# for x and grad
return ScalingMode.NVFP4_1D_SCALING
def _make_rht_quantize_meta(self, q_layout, tensor_source: TensorSource) -> QuantizeMeta:
"""Create the quantization metadata for RHT if applicable."""
# Imported here to prevent circular import
from transformer_engine.jax.quantize import QuantizeLayout
use_rht = self.get_scaling_mode(
tensor_source
) == ScalingMode.NVFP4_1D_SCALING and q_layout in {
QuantizeLayout.ROWWISE_COLWISE,
QuantizeLayout.COLWISE,
}
if self.DISABLE_RHT:
use_rht = False
return QuantizeMeta(use_rht=use_rht)
def _make_stochastic_rounding_rng_state(
self, module, tensor_source: TensorSource, quantizer_name: str
) -> QuantizeMeta:
"""Create the stochastic rounding rng state if applicable."""
if self.DISABLE_STOCHASTIC_ROUNDING:
return QuantizeMeta()
if tensor_source != TensorSource.DGRAD:
# Only DGRAD uses stochastic rounding
return QuantizeMeta()
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
# Use hashlib to get a deterministic hash value for quantizer_name
quantizer_hash = (
int(hashlib.sha256(quantizer_name.encode("utf-8")).hexdigest(), 16)
% jnp.iinfo(jnp.int32).max
)
sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash)
# Generate 4 random uint32 values from the JAX PRNG key (one row of 4 per device when running on a multi-device mesh)
shape = (4,)
if get_num_devices_in_mesh() > 1:
shape = (get_num_devices_in_mesh(), 4)
sr_jax_rng_state = jax.random.randint(
sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
)
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
def get_quantize_flax_meta(
self,
module,
......@@ -603,27 +662,14 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
if tensor_source != TensorSource.DGRAD:
# Only DGRAD uses stochastic rounding
return QuantizeMeta()
# TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it.
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
sr_jax_rng = jax.jit(jax.random.fold_in)(
sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max
)
# Imported here to prevent circular import
from transformer_engine.jax.quantize import QuantizeLayout
# Generate 4 random uint32 values from the JAX PRNG key
sr_jax_rng_state = jax.random.randint(
sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
return QuantizeMeta.merge(
self._make_rht_quantize_meta(QuantizeLayout.ROWWISE_COLWISE, tensor_source),
self._make_stochastic_rounding_rng_state(module, tensor_source, quantizer_name),
)
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
......
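
A quick illustration of why this change switches from Python's built-in hash() to hashlib for the per-quantizer fold-in value: str hashes are salted per process via PYTHONHASHSEED, while a sha256 digest is stable across runs (the quantizer name below is hypothetical):

import hashlib
import jax.numpy as jnp

name = "dgrad_quantizer"  # hypothetical quantizer_name
# Deterministic across processes and runs:
quantizer_hash = (
    int(hashlib.sha256(name.encode("utf-8")).hexdigest(), 16)
    % jnp.iinfo(jnp.int32).max
)
# By contrast, hash(name) can differ between interpreter invocations.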
......@@ -26,6 +26,26 @@ class QuantizeMeta:
"""
@staticmethod
def merge(a: "QuantizeMeta", b: "QuantizeMeta") -> "QuantizeMeta":
"""Merge two QuantizeMeta instances.
Args:
a (QuantizeMeta): The first QuantizeMeta instance.
b (QuantizeMeta): The second QuantizeMeta instance.
Returns:
QuantizeMeta: A new QuantizeMeta instance with merged metadata.
"""
assert isinstance(a, QuantizeMeta)
assert isinstance(b, QuantizeMeta)
for key in b.get_kwargs_dictionary().keys():
if key in a.get_kwargs_dictionary():
assert (
a.get_kwargs_dictionary()[key] == b.get_kwargs_dictionary()[key]
), f"Conflict in merging QuantizeMeta: {key} has different values."
return QuantizeMeta(**{**a.get_kwargs_dictionary(), **b.get_kwargs_dictionary()})
def __init__(self, **kwargs):
self._kwargs = kwargs
......
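
As a usage note, the helper.py change above combines RHT and stochastic-rounding metadata through exactly this merge; a minimal sketch, where sr_state stands in for a uint32 rng-state array:

meta = QuantizeMeta.merge(
    QuantizeMeta(use_rht=True),
    QuantizeMeta(stochastic_rounding_rng_state=sr_state),
)
# merge() asserts if the same key appears in both instances with different values.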
......@@ -19,7 +19,7 @@ from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
from .hadamard import apply_rht
from .tensor import (
ScaledTensor,
ScaledTensor1x,
......@@ -590,11 +590,13 @@ class NVFP4Quantizer(Quantizer):
q_layout: Quantization axis
data_layout: Data layout string (default: "NT")
stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled.
use_rht: Whether to apply Randomized Hadamard Transform (RHT) before quantization.
"""
scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
data_layout: str = "NT"
use_rht: bool = False
stochastic_rounding_rng_state: Optional[jnp.ndarray] = None
def __post_init__(self):
......@@ -603,6 +605,30 @@ class NVFP4Quantizer(Quantizer):
), "NVFP4 quantization must use a q_dtype of float4_e2m1fn"
assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes"
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.stochastic_rounding_rng_state,)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.use_rht)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstruct a quantizer from its flattened representation.
Args:
aux_data: Auxiliary data containing quantizer parameters
children: Child arrays containing the stochastic rounding rng state
Returns:
A reconstructed Quantizer instance
"""
stochastic_rounding_rng_state = children[0]
return cls(*aux_data, stochastic_rounding_rng_state=stochastic_rounding_rng_state)
def _apply_stochastic_rounding(self, x):
assert (
self.stochastic_rounding_rng_state is not None
......@@ -688,8 +714,9 @@ class NVFP4Quantizer(Quantizer):
flatten_axis = x.ndim - flatten_axis
x_shape = x.shape
if should_use_rht(self.scaling_mode, is_colwise=is_colwise):
# We only apply RHT for 1D colwise nvfp4
# We currently only have a single flag 'use_rht' on the quantizer. To avoid an unused rowwise flag, we assume RHT is only used for colwise quantization for now.
use_rht = self.use_rht and is_colwise and self.scaling_mode == ScalingMode.NVFP4_1D_SCALING
if use_rht:
x = apply_rht(x)
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
......@@ -790,6 +817,7 @@ class NVFP4Quantizer(Quantizer):
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=rowwise_flatten_axis,
has_rht_applied=use_rht,
)
......
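
The tree_flatten/tree_unflatten pair above is the fix for the SR state being erased across VJP boundaries: the rng state now travels as a pytree leaf rather than being dropped on reconstruction. A minimal round-trip sketch (constructor arguments assumed from the aux_data ordering above):

import jax
import jax.numpy as jnp

state = jnp.zeros((4,), jnp.uint32)  # placeholder SR rng state
q = NVFP4Quantizer(q_dtype=jnp.float4_e2m1fn, stochastic_rounding_rng_state=state)
leaves, treedef = jax.tree_util.tree_flatten(q)
q2 = jax.tree_util.tree_unflatten(treedef, leaves)
assert q2.stochastic_rounding_rng_state is not None  # SR state survives flattening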
......@@ -175,6 +175,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise: Whether the tensor uses column-wise quantization
data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor
has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization
"""
scale_inv: jnp.ndarray
......@@ -184,6 +185,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise: bool
data_layout: str
flatten_axis: int
has_rht_applied: bool
def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization.
......@@ -243,6 +245,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
self.is_colwise,
self.data_layout,
self.flatten_axis,
self.has_rht_applied,
)
return (children, aux_data)
......@@ -314,6 +317,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise=self.is_colwise,
data_layout=self.data_layout,
flatten_axis=self.flatten_axis,
has_rht_applied=self.has_rht_applied,
)
......@@ -354,6 +358,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self.group_sizes = group_sizes
self.original_shape = original_shape
self.group_axis = group_axis
# TODO(Phuong): Handle RHT for grouped quantization once grouped quantization supports NVFP4
super().__init__(
data=data,
scale_inv=scale_inv,
......@@ -364,6 +369,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
has_rht_applied=False,
)
def __post_init__(self):
......@@ -515,6 +521,7 @@ class ScaledTensorFactory:
group_sizes=None,
original_shape=None,
group_axis=0,
has_rht_applied=False,
):
"""Creates a single-scale quantized tensor.
......@@ -530,6 +537,7 @@ class ScaledTensorFactory:
group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)
Returns:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
......@@ -593,6 +601,7 @@ class ScaledTensorFactory:
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
has_rht_applied=has_rht_applied,
)
@staticmethod
......@@ -610,6 +619,8 @@ class ScaledTensorFactory:
group_sizes=None,
original_shape=None,
group_axis=0,
rowwise_has_rht_applied=False,
colwise_has_rht_applied=False,
):
"""Creates a double-scale quantized tensor.
......@@ -626,6 +637,8 @@ class ScaledTensorFactory:
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
rowwise_has_rht_applied: Whether the row-wise tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)
colwise_has_rht_applied: Whether the column-wise tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)
Returns:
A ScaledTensor2x instance
......@@ -648,6 +661,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
has_rht_applied=rowwise_has_rht_applied,
)
colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data,
......@@ -661,6 +675,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
has_rht_applied=colwise_has_rht_applied,
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
......@@ -680,6 +695,8 @@ class ScaledTensorFactory:
group_sizes: jnp.ndarray = None,
original_shape: Tuple[int] = None,
group_axis: int = 0,
rowwise_has_rht_applied: bool = False,
colwise_has_rht_applied: bool = False,
):
"""Creates a scaled tensor based on the quantization axis.
......@@ -696,10 +713,14 @@ class ScaledTensorFactory:
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
rowwise_has_rht_applied: Whether the row-wise tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)
colwise_has_rht_applied: Whether the column-wise tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
"""
assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"
if q_layout == QuantizeLayout.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x(
data,
......@@ -715,6 +736,8 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
rowwise_has_rht_applied=rowwise_has_rht_applied,
colwise_has_rht_applied=colwise_has_rht_applied,
)
is_colwise = q_layout == QuantizeLayout.COLWISE
......@@ -731,6 +754,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
has_rht_applied=colwise_has_rht_applied,
)
return ScaledTensorFactory.create_1x(
......@@ -745,6 +769,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
has_rht_applied=rowwise_has_rht_applied,
)
......
......@@ -238,6 +238,19 @@ def num_of_devices():
return len(jax.devices())
def get_num_devices_in_mesh(mesh=None):
"""
Get the number of devices in the given mesh.
If the mesh is None, the global mesh is used.
Returns 1 if no mesh is active (i.e. the mesh is empty).
"""
if mesh is None:
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return 1
return np.prod(list(mesh.shape.values()))
def get_mesh_axis_size(axis, mesh=None):
"""
Get the axis size of the given mesh.
......
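
For clarity on the rename above: num_of_devices() counts every device visible to the process, while get_num_devices_in_mesh() counts only the devices participating in the active mesh (an illustrative comparison, assuming an 8-device process running a 2-device mesh):

# num_of_devices()          -> 8  (len(jax.devices()), process-wide)
# get_num_devices_in_mesh() -> 2  (product of the active mesh's axis sizes)
# get_num_devices_in_mesh() -> 1  (when no mesh is active, i.e. mesh.empty)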