Unverified Commit cd2034f3 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

Lower precision gated-act to accelerate FP8 current-scaling. (#2153)



* Applying the original precision to Norm outputs and activation computations.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding knob to control norm output precision.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Removing the knob and applying lower-precision norm with current-scaling only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix the error when quantizer==None
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
parent 405d474b
...@@ -465,14 +465,23 @@ class TestNorm: ...@@ -465,14 +465,23 @@ class TestNorm:
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
) )
ref_out, ref_mu, ref_rsigma = _jax_layernorm( ref_out, ref_mu, ref_rsigma = _jax_layernorm(
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer x,
gamma,
beta,
zero_centered_gamma,
epsilon,
quantizer=ref_quantizer,
) )
else: else:
output, rsigma = tex.rmsnorm_fwd( output, rsigma = tex.rmsnorm_fwd(
x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
) )
ref_out, ref_rsigma = _jax_rmsnorm( ref_out, ref_rsigma = _jax_rmsnorm(
x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer x,
gamma,
zero_centered_gamma,
epsilon,
quantizer=ref_quantizer,
) )
ref_mu = None ref_mu = None
......
...@@ -1045,7 +1045,7 @@ def act_lu( ...@@ -1045,7 +1045,7 @@ def act_lu(
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after. # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = act_lu( out = act_lu(
x=x.astype(jnp.float32), x=x,
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
) )
...@@ -1178,8 +1178,8 @@ def quantize_dact_dbias( ...@@ -1178,8 +1178,8 @@ def quantize_dact_dbias(
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after. # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = dact_lu( out = dact_lu(
dz=dz.astype(jnp.float32), dz=dz,
x=x.astype(jnp.float32), x=x,
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
) )
......
...@@ -842,6 +842,8 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) ...@@ -842,6 +842,8 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
output = normed_input * gamma + beta output = normed_input * gamma + beta
if quantizer: if quantizer:
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
output = output.astype(x.dtype)
ln_out = quantizer.quantize(output, dq_dtype=x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else: else:
ln_out = jnp.asarray(output).astype(x.dtype) ln_out = jnp.asarray(output).astype(x.dtype)
...@@ -867,6 +869,8 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): ...@@ -867,6 +869,8 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
output = normed_input * gamma output = normed_input * gamma
if quantizer: if quantizer:
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
output = output.astype(x.dtype)
ln_out = quantizer.quantize(output, dq_dtype=x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else: else:
ln_out = jnp.asarray(output).astype(x.dtype) ln_out = jnp.asarray(output).astype(x.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment