Unverified commit e2f2a0b4, authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Make SR rng state always 2D (num_devices, 4) to fix partitioning issue (#2294)



* Make SR rng state always 2D (num_devices, 4)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix pure-jax impl
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix test shape
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent eb34783c
@@ -876,7 +876,7 @@ class TestStochasticRounding:
         for i in range(num_samples):
             iter_key = jax.random.fold_in(key, i)
             sr_rng_state = jax.random.randint(
-                iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
+                iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
             )
             quantizer = QuantizerFactory.create(
                 q_dtype=q_dtype,
@@ -631,10 +631,8 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
         )
         sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash)
-        # Generate 4 random uint32 values from the JAX PRNG key
-        shape = (4,)
-        if get_num_devices_in_mesh() > 1:
-            shape = (get_num_devices_in_mesh(), 4)
+        # Generate 4 random uint32 values per device from the JAX PRNG key
+        shape = (get_num_devices_in_mesh(), 4)
         sr_jax_rng_state = jax.random.randint(
             sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
         ).view(jnp.uint32)
@@ -34,6 +34,7 @@ from .helper import (
     TensorSource,
 )
 from .device_utils import is_fp8_gemm_with_all_layouts_supported
+from ..sharding import get_num_devices_in_mesh

 __all__ = [
     "QuantizeLayout",
@@ -633,9 +634,11 @@ class NVFP4Quantizer(Quantizer):
         assert (
             self.stochastic_rounding_rng_state is not None
         ), "Stochastic rounding RNG state is not initialized"
-        assert self.stochastic_rounding_rng_state.shape == (
-            4,
-        ), "Stochastic rounding RNG state must be of shape (4,)"
+        expected_sr_rng_state_shape = (get_num_devices_in_mesh(), 4)
+        assert self.stochastic_rounding_rng_state.shape == expected_sr_rng_state_shape, (
+            "Stochastic rounding RNG state must be of shape (num_devices_in_mesh, 4). Expected"
+            f" {expected_sr_rng_state_shape}, but got {self.stochastic_rounding_rng_state.shape}"
+        )
         assert (
             self.stochastic_rounding_rng_state.dtype == jnp.uint32
         ), "Stochastic rounding RNG state must be of dtype uint32"
@@ -643,14 +646,15 @@ class NVFP4Quantizer(Quantizer):
         # Default RNG state in JAX expects 2x 32-bit integers, use first 2 uint32s for initial state and fold in the other 2 uint32s
         key_bits = jnp.array(
             [
-                self.stochastic_rounding_rng_state[0],
-                self.stochastic_rounding_rng_state[1],
+                # only take the first device's RNG state as the pure-JAX stochastic rounding impl only uses a single-device
+                self.stochastic_rounding_rng_state[0][0],
+                self.stochastic_rounding_rng_state[0][1],
             ],
             dtype=jnp.uint32,
         )
         key = jax.random.wrap_key_data(key_bits)
-        key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[2])
-        key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[3])
+        key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][2])
+        key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][3])
         abs_x = jnp.abs(x)
         sign_x = jnp.sign(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment