Unverified Commit a8e4346e authored by Phuong Nguyen, committed by GitHub

[JAX] Use TE quantization when TE fused norm is disabled (#2303)



* jax norm + te quant
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4cf2f12b
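
The diff below routes the pure-JAX (unfused) norm fallback through TE's `quantize` helper instead of passing the quantizer into `_jax_layernorm` / `_jax_rmsnorm` directly. As a rough, self-contained sketch of that control flow, with a hypothetical `fake_quantize` standing in for TE's real `quantize` (the scaling and cast here are illustrative assumptions, not TE's API):

```python
# Minimal sketch of the new fallback path: compute layernorm in plain JAX,
# then quantize the normalized output as a separate step.
# `fake_quantize` is a hypothetical stand-in for TE's quantize(); the real
# helper takes a Quantizer object, amax_scope, flatten_axis, etc.
import jax.numpy as jnp

def fake_quantize(x, scale):
    # Stand-in: scale and cast to FP8 (e4m3).
    return (x * scale).astype(jnp.float8_e4m3fn)

def jax_layernorm_then_quantize(x, gamma, beta, epsilon=1e-5, scale=None):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    rsigma = 1.0 / jnp.sqrt(jnp.var(x, axis=-1, keepdims=True) + epsilon)
    output = (x - mu) * rsigma * gamma + beta
    if scale is not None:  # mirrors `if quantizer is not None:` in the diff
        output = fake_quantize(output, scale)
    return output, mu, rsigma
```
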
@@ -27,7 +27,7 @@ from .misc import (
     NamedSharding,
     get_cudnn_version,
 )
-from .quantization import _quantize_dbias_impl, AmaxScope
+from .quantization import quantize, AmaxScope
 from ..sharding import (
     all_reduce_max_along_all_axes_except_PP,
     all_reduce_sum_along_dp_fsdp_tpsp,
@@ -945,7 +945,7 @@ def layernorm_fwd(
     beta: jnp.ndarray,
     zero_centered_gamma: bool,
     epsilon: float,
-    quantizer: Optional[Quantizer],
+    quantizer: Optional[Quantizer] = None,
     amax_scope: AmaxScope = AmaxScope.LOCAL,
     transpose_batch_sequence: bool = False,
     output_amax_when_no_scaling: bool = False,
@@ -975,7 +975,16 @@ def layernorm_fwd(
         - Reciprocal of the standard deviation of the input tensor. Shape: (..., 1)
     """
     if not NormFwdPrimitive.enabled():
-        return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
+        output, mu, rsigma = _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
+        if quantizer is not None:
+            output = quantize(
+                output,
+                quantizer,
+                flatten_axis=-1,
+                amax_scope=amax_scope,
+                transpose_batch_sequence=transpose_batch_sequence,
+            )
+        return (output, mu, rsigma)

     # TE/common does not support normalization with colwise only quantization yet
     if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
@@ -1029,7 +1038,7 @@ def layernorm_fwd(
         transpose_batch_sequence=transpose_batch_sequence,
         output_amax_when_no_scaling=False,
     )
-    out, _ = _quantize_dbias_impl(
+    out, _ = quantize(
         out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence
     )
     return out, mu, rsigma
@@ -1050,11 +1059,9 @@ def layernorm_fwd(
         transpose_batch_sequence=transpose_batch_sequence,
         output_amax_when_no_scaling=True,
     )
-    out, _ = _quantize_dbias_impl(
+    out = quantize(
         out,
-        is_dbias=False,
         quantizer=quantizer,
-        dq_dtype=x.dtype,
         amax_scope=amax_scope,
         transpose_batch_sequence=transpose_batch_sequence,
     )
@@ -1219,7 +1226,16 @@ def rmsnorm_fwd(
         Shape: (..., 1)
     """
     if not NormFwdPrimitive.enabled():
-        return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
+        output, rsigma = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon)
+        if quantizer is not None:
+            output = quantize(
+                output,
+                quantizer,
+                flatten_axis=-1,
+                amax_scope=amax_scope,
+                transpose_batch_sequence=transpose_batch_sequence,
+            )
+        return (output, rsigma)

     # TE/common does not support normalization with colwise only quantization yet
     if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
@@ -1274,7 +1290,7 @@ def rmsnorm_fwd(
         transpose_batch_sequence=transpose_batch_sequence,
         output_amax_when_no_scaling=False,
     )
-    out, _ = _quantize_dbias_impl(
+    out = quantize(
         out.data,
         quantizer,
         amax_scope=amax_scope,
@@ -1297,11 +1313,9 @@ def rmsnorm_fwd(
         transpose_batch_sequence=transpose_batch_sequence,
         output_amax_when_no_scaling=True,
     )
-    out, _ = _quantize_dbias_impl(
+    out = quantize(
         out,
-        is_dbias=False,
         quantizer=quantizer,
-        dq_dtype=x.dtype,
         amax_scope=amax_scope,
         transpose_batch_sequence=transpose_batch_sequence,
     )