Unverified Commit a1c18bc8 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] WAR for CuDNN MXFP8 norm incorrect result (#1700)



Check the CuDNN version and fall back to an unfused norm
if it is below the version that contains the fix
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent c7702309
......@@ -23,6 +23,7 @@ from transformer_engine.jax.cpp_extensions.quantization import (
_jax_quantize,
_jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
DelayedScaleQuantizer,
......@@ -395,6 +396,11 @@ class TestNorm:
)
ref_mu = None
if get_cudnn_version() < (9, 10, 0):
# Reduce the precision of this test: below this CuDNN version we don't use the fused norm for MXFP8
# and instead do an unfused norm and quantize, with an intermediate cast into in_dtype which can reduce precision
assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
else:
assert_bitwise_scaled_tensors(output, ref_out)
assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
if norm_type == "layernorm":
......
......@@ -26,6 +26,7 @@ from .misc import (
jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype,
NamedSharding,
get_cudnn_version,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
......@@ -35,6 +36,7 @@ from ..quantize import (
DelayedScaleQuantizer,
ScalingMode,
)
from .quantization import _quantize_dbias_impl
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
......@@ -85,6 +87,10 @@ def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bo
return int(os.getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE", "0")) == 1
# The CuDNN version must be at least this to use MXFP8 fused normalization; otherwise an unfused norm and quantize will be used
FUSED_MXFP8_NORM_CUDNN_MIN_VERSION = (9, 10, 0)
class NormFwdPrimitive(BasePrimitive):
"""
Layer Normalization Forward FP8 Primitive
......@@ -122,6 +128,14 @@ class NormFwdPrimitive(BasePrimitive):
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
assert (
scaling_mode != ScalingMode.MXFP8_1D_SCALING.value
or get_cudnn_version() >= FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
), (
"MXFP8 Fused Normalization is only supported in CuDNN version"
f" {FUSED_MXFP8_NORM_CUDNN_MIN_VERSION} or higher"
)
mu_rsigama_dtype = jnp.float32
if norm_type == NVTE_Norm_Type.LayerNorm:
......@@ -913,6 +927,16 @@ def layernorm_fwd(
)
return output, mu, rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
):
out, mu, rsigma = layernorm_fwd(
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None
)
out, _ = _quantize_dbias_impl(out, quantizer)
return out, mu, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
......@@ -1095,6 +1119,14 @@ def rmsnorm_fwd(
)
return output, rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
):
out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None)
out, _ = _quantize_dbias_impl(out, quantizer)
return out, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment