Unverified Commit a1c18bc8 authored by jberchtold-nvidia, committed by GitHub

[JAX] WAR for CuDNN MXFP8 norm incorrect result (#1700)



Check the CuDNN version and fall back to an unfused norm and quantize when the installed version predates the one containing the fix.
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent c7702309
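The change follows a simple pattern: gate the fused MXFP8 normalization on the installed CuDNN version and, when that version predates the fix, run an unquantized normalization first and quantize the result in a separate step. The following is a minimal, hypothetical sketch of that pattern in JAX, not the Transformer Engine implementation; cudnn_version, fused_norm_fn, and quantize_fn are placeholder names introduced only for illustration.

import jax.numpy as jnp

# First CuDNN release assumed to contain the MXFP8 norm fix (taken from the patch below).
FUSED_MXFP8_NORM_CUDNN_MIN_VERSION = (9, 10, 0)


def layernorm_mxfp8(x, gamma, beta, cudnn_version, eps=1e-5, fused_norm_fn=None, quantize_fn=None):
    """Version-gated dispatch: fused MXFP8 norm when CuDNN is new enough, unfused WAR otherwise.

    fused_norm_fn and quantize_fn are placeholder callables standing in for the fused
    kernel and the MXFP8 quantization step; they are not part of the real API.
    """
    if fused_norm_fn is not None and cudnn_version >= FUSED_MXFP8_NORM_CUDNN_MIN_VERSION:
        # Fast path: a single fused kernel normalizes and quantizes to MXFP8.
        return fused_norm_fn(x, gamma, beta, eps)
    # WAR path: compute the norm unfused in plain JAX, then quantize afterwards.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    out = (x - mu) / jnp.sqrt(var + eps) * gamma + beta
    return quantize_fn(out) if quantize_fn is not None else out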
@@ -23,6 +23,7 @@ from transformer_engine.jax.cpp_extensions.quantization import (
     _jax_quantize,
     _jax_quantize_dbias,
 )
+from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
 from transformer_engine.jax import cpp_extensions as tex
 from transformer_engine.jax.quantize import (
     DelayedScaleQuantizer,
@@ -395,7 +396,12 @@ class TestNorm:
         )
         ref_mu = None
 
-        assert_bitwise_scaled_tensors(output, ref_out)
+        if get_cudnn_version() < (9, 10, 0):
+            # Reduce the test precision: below this CuDNN version we don't use the fused norm for MXFP8
+            # and instead do an unfused norm and quantize with an intermediate cast to in_dtype, which can reduce precision.
+            assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
+        else:
+            assert_bitwise_scaled_tensors(output, ref_out)
 
         assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
         if norm_type == "layernorm":
             assert_allclose(mu, ref_mu, dtype=inp_dtype)
...
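The relaxed tolerance in the test above comes from the extra cast the workaround introduces: the unfused path materializes the normalized output in the input dtype before quantizing, while the fused kernel quantizes from a higher-precision intermediate. Below is a small, self-contained illustration of how such an intermediate cast perturbs values; it uses only jax.numpy, and the array and dtypes are made up for the example.

import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 8, dtype=jnp.float32)

# Fused-style path: the quantizer would see the full-precision values directly.
direct = x
# WAR-style path: an intermediate cast to a lower-precision dtype (here bfloat16)
# happens before quantization, so the quantizer sees slightly different inputs.
via_cast = x.astype(jnp.bfloat16).astype(jnp.float32)

# The nonzero difference is why bitwise equality is replaced by assert_allclose.
print(jnp.max(jnp.abs(direct - via_cast)))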
@@ -26,6 +26,7 @@ from .misc import (
     jax_dtype_to_te_dtype,
     te_dtype_to_jax_dtype,
     NamedSharding,
+    get_cudnn_version,
 )
 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
 from ..quantize import ScaledTensor, ScaledTensorFactory
@@ -35,6 +36,7 @@ from ..quantize import (
     DelayedScaleQuantizer,
     ScalingMode,
 )
+from .quantization import _quantize_dbias_impl
 
 if version.parse(jax.__version__) >= version.parse("0.5.0"):
     from jax import ffi  # pylint: disable=ungrouped-imports
@@ -85,6 +87,10 @@ def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bo
     return int(os.getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE", "0")) == 1
 
 
+# The CuDNN version must be at least this to use MXFP8 fused normalization; otherwise an unfused norm and quantize will be used.
+FUSED_MXFP8_NORM_CUDNN_MIN_VERSION = (9, 10, 0)
+
+
 class NormFwdPrimitive(BasePrimitive):
     """
     Layer Normalization Forward FP8 Primitive
@@ -122,6 +128,14 @@ class NormFwdPrimitive(BasePrimitive):
         assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
         assert scale_aval is None or scale_aval.dtype == jnp.float32
+        assert (
+            scaling_mode != ScalingMode.MXFP8_1D_SCALING.value
+            or get_cudnn_version() >= FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
+        ), (
+            "MXFP8 Fused Normalization is only supported in CuDNN version"
+            f" {FUSED_MXFP8_NORM_CUDNN_MIN_VERSION} or higher"
+        )
+
         mu_rsigama_dtype = jnp.float32
 
         if norm_type == NVTE_Norm_Type.LayerNorm:
@@ -913,6 +927,16 @@ def layernorm_fwd(
         )
         return output, mu, rsigma
 
+    if (
+        quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
+        and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
+    ):
+        out, mu, rsigma = layernorm_fwd(
+            x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None
+        )
+        out, _ = _quantize_dbias_impl(out, quantizer)
+        return out, mu, rsigma
+
     is_2x2x = quantizer.is_2x2x()
     # TE/common normalization doesn't support 2x delayed scaling
     if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
@@ -1095,6 +1119,14 @@ def rmsnorm_fwd(
         )
         return output, rsigma
 
+    if (
+        quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
+        and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
+    ):
+        out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None)
+        out, _ = _quantize_dbias_impl(out, quantizer)
+        return out, rsigma
+
     is_2x2x = quantizer.is_2x2x()
     # TE/common normalization doesn't support 2x delayed scaling
     if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
...