Unverified Commit ca7407e3 authored by jberchtold-nvidia, committed by GitHub

[JAX] Update tolerance of distributed layernorm MLP for FP8 (#1971)



Update tolerance of distributed layernorm MLP for FP8
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 86c50977
@@ -389,6 +389,24 @@ class TestDistributedLayernormMLP:
             atol = 0.04
             rtol = 11
 
+        # JAX's FP8 GEMM, jax.lax.dot_general, now uses the
+        # Triton backend by default. The error of
+        # the Triton FP8 gemm has been verified to be less than or equal
+        # to the error of the cuDNN FP8 gemm w.r.t a float32 ground truth.
+        # However, Triton can auto-tune a different kernel for the single GPU
+        # and multi-GPU run in this test, meaning the diff between single GPU
+        # and multi-GPU can be larger in some cases, even though both are
+        # within tolerance to the float32 ground truth.
+        jax_triton_gemm_precision_tolerance_update = (
+            with_jax_gemm
+            and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
+            and dtype == jnp.bfloat16
+            and activation_type == ("gelu", "linear")
+        )
+        if jax_triton_gemm_precision_tolerance_update:
+            atol = 0.08
+            rtol = 15
+
         assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)
 
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
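Reviewer note: the reasoning in the new code comment can be made concrete with a small standalone sketch. This is not part of the commit and the numeric values are invented for illustration; it only shows that when the single-GPU and multi-GPU outputs are each within atol of a float32 ground truth, their mutual difference is only bounded by 2 * atol (triangle inequality), which is why the direct single-vs-multi-GPU comparison gets the looser 0.08 tolerance when Triton may auto-tune different kernels for the two runs.

# Illustrative sketch only (not part of the commit); values are made up.
import jax.numpy as jnp

atol = 0.04                      # tolerance of each run w.r.t. the float32 ground truth
ground_truth = jnp.float32(1.0)  # stand-in for the float32 reference output

# Hypothetical single-GPU and multi-GPU results, each produced by a
# (possibly different) auto-tuned FP8 GEMM kernel.
single_gpu_out = ground_truth + 0.03  # within atol of the ground truth
multi_gpu_out = ground_truth - 0.03   # also within atol of the ground truth

assert jnp.abs(single_gpu_out - ground_truth) <= atol
assert jnp.abs(multi_gpu_out - ground_truth) <= atol

# The two runs can still differ from each other by up to 2 * atol, so the
# direct comparison in the test needs the looser tolerance (0.08 = 2 * 0.04).
assert jnp.abs(single_gpu_out - multi_gpu_out) <= 2 * atol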