Unverified commit e2f2a0b4, authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Make SR rng state always 2D (num_devices, 4) to fix partitioning issue (#2294)



* Make SR rng state always 2D (num_devices, 4)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix pure-jax impl
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix test shape
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent eb34783c
......@@ -876,7 +876,7 @@ class TestStochasticRounding:
for i in range(num_samples):
iter_key = jax.random.fold_in(key, i)
sr_rng_state = jax.random.randint(
iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
)
quantizer = QuantizerFactory.create(
q_dtype=q_dtype,
......
......@@ -631,10 +631,8 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
)
sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash)
# Generate 4 random uint32 values from the JAX PRNG key
shape = (4,)
if get_num_devices_in_mesh() > 1:
shape = (get_num_devices_in_mesh(), 4)
# Generate 4 random uint32 values per device from the JAX PRNG key
shape = (get_num_devices_in_mesh(), 4)
sr_jax_rng_state = jax.random.randint(
sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
).view(jnp.uint32)
......
......@@ -34,6 +34,7 @@ from .helper import (
TensorSource,
)
from .device_utils import is_fp8_gemm_with_all_layouts_supported
from ..sharding import get_num_devices_in_mesh
__all__ = [
"QuantizeLayout",
......@@ -633,9 +634,11 @@ class NVFP4Quantizer(Quantizer):
assert (
self.stochastic_rounding_rng_state is not None
), "Stochastic rounding RNG state is not initialized"
assert self.stochastic_rounding_rng_state.shape == (
4,
), "Stochastic rounding RNG state must be of shape (4,)"
expected_sr_rng_state_shape = (get_num_devices_in_mesh(), 4)
assert self.stochastic_rounding_rng_state.shape == expected_sr_rng_state_shape, (
"Stochastic rounding RNG state must be of shape (num_devices_in_mesh, 4). Expected"
f" {expected_sr_rng_state_shape}, but got {self.stochastic_rounding_rng_state.shape}"
)
assert (
self.stochastic_rounding_rng_state.dtype == jnp.uint32
), "Stochastic rounding RNG state must be of dtype uint32"
......@@ -643,14 +646,15 @@ class NVFP4Quantizer(Quantizer):
# Default RNG state in JAX expects 2x 32-bit integers, use first 2 uint32s for initial state and fold in the other 2 uint32s
key_bits = jnp.array(
[
self.stochastic_rounding_rng_state[0],
self.stochastic_rounding_rng_state[1],
# only take the first device's RNG state as the pure-JAX stochastic rounding impl only uses a single-device
self.stochastic_rounding_rng_state[0][0],
self.stochastic_rounding_rng_state[0][1],
],
dtype=jnp.uint32,
)
key = jax.random.wrap_key_data(key_bits)
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[2])
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[3])
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][2])
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][3])
abs_x = jnp.abs(x)
sign_x = jnp.sign(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment