Unverified commit dac098d8 authored by jberchtold-nvidia, committed by GitHub

[JAX] Fix distributed Layernorm test failure (#1734)



Fix distributed layernorm test failure
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 2f61c401
@@ -78,7 +78,7 @@ class TestDistributedLayernorm:
         if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
             other_bytes = 384  # required for small scale shapes that require padding
         if fp8_recipe == recipe.Float8CurrentScaling():
-            allreduce_total_bytes += 4  # 1 * FP32 for the amax reduction
+            allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction
         return generate_collectives_count(
             allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
         )
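The test hunk above now sizes the expected amax allreduce by the input dtype instead of hard-coding 4 bytes of FP32. A minimal sketch of that accounting, assuming jax_dtype is a NumPy dtype object as in the surrounding test (the helper name below is hypothetical):

    import numpy as np
    import jax.numpy as jnp

    def expected_amax_allreduce_bytes(jax_dtype) -> int:
        # Under current scaling, the per-device amax is all-reduced in the
        # input dtype; the cast to FP32 happens after the reduction, so the
        # collective moves one element of the input dtype, not one FP32 value.
        return np.dtype(jax_dtype).itemsize

    expected_amax_allreduce_bytes(jnp.bfloat16)  # 2, matches jax_dtype.itemsize
    expected_amax_allreduce_bytes(jnp.float32)   # 4, the old hard-coded value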
@@ -614,7 +614,7 @@ def _quantize_dbias_impl(
         # Globally reduce amax across all devices for current scaling so we have a single global scale.
         # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
         # until the tensor is dequantized (e.g. in the GEMM).
-        amax = jnp.amax(jnp.abs(x), keepdims=True)
+        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
         scale = compute_scale_from_amax(amax, quantizer.q_dtype)
     if isinstance(quantizer, DelayedScaleQuantizer):
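For context, current scaling derives one global scale from the reduced amax, and the astype(jnp.float32) above keeps that scale math in FP32 even when x is BF16 or FP16. A hedged sketch of the scale computation follows; the real compute_scale_from_amax in Transformer Engine may differ (e.g. in margin handling or the zero-amax case):

    import jax.numpy as jnp

    def compute_scale_from_amax_sketch(amax: jnp.ndarray, q_dtype) -> jnp.ndarray:
        # Largest representable magnitude of the FP8 target format
        # (e.g. 448 for float8_e4m3fn, 57344 for float8_e5m2).
        fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, dtype=jnp.float32)
        amax = amax.astype(jnp.float32)
        # The scale maps the observed range onto the FP8 range (x_q = x * scale);
        # guard against amax == 0 to avoid an infinite scale.
        return jnp.where(amax > 0.0, fp8_max / amax, jnp.ones_like(amax))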