Unverified Commit cd2034f3 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

Lower precision gated-act to accelerate FP8 current-scaling. (#2153)



* Applying the original precision to Norm outputs and activation computations.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding knob to control norm output precision.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Removing the knob and applying lower-precision norm with current-scaling only.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the error when quantizer==None
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent 405d474b
......@@ -465,14 +465,23 @@ class TestNorm:
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
)
ref_out, ref_mu, ref_rsigma = _jax_layernorm(
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
quantizer=ref_quantizer,
)
else:
output, rsigma = tex.rmsnorm_fwd(
x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
)
ref_out, ref_rsigma = _jax_rmsnorm(
x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
x,
gamma,
zero_centered_gamma,
epsilon,
quantizer=ref_quantizer,
)
ref_mu = None
......
......@@ -1045,7 +1045,7 @@ def act_lu(
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = act_lu(
x=x.astype(jnp.float32),
x=x,
activation_type=activation_type,
quantizer=None,
)
......@@ -1178,8 +1178,8 @@ def quantize_dact_dbias(
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = dact_lu(
dz=dz.astype(jnp.float32),
x=x.astype(jnp.float32),
dz=dz,
x=x,
activation_type=activation_type,
quantizer=None,
)
......
......@@ -842,6 +842,8 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
output = normed_input * gamma + beta
if quantizer:
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
output = output.astype(x.dtype)
ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else:
ln_out = jnp.asarray(output).astype(x.dtype)
......@@ -867,6 +869,8 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
output = normed_input * gamma
if quantizer:
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
output = output.astype(x.dtype)
ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else:
ln_out = jnp.asarray(output).astype(x.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment