"vscode:/vscode.git/clone" did not exist on "2a95efd39128955081c60b67d49351d89f003324"
Unverified Commit 389a6ba4 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Use TE quant if TE fused act is disabled (#2374)



* Use TE quant if TE fused act is disabled
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Keep existing precision
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent c5257605
...@@ -27,7 +27,7 @@ from .misc import ( ...@@ -27,7 +27,7 @@ from .misc import (
should_apply_1x_fused_dbias_war_for_arch_l_100, should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding, NamedSharding,
) )
from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope from .quantization import _jax_dbias, quantize, quantize_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
...@@ -1268,7 +1268,19 @@ def act_lu( ...@@ -1268,7 +1268,19 @@ def act_lu(
) )
act_params = act_params if act_params is not None else ActivationParams() act_params = act_params if act_params is not None else ActivationParams()
if not ActLuPrimitive.enabled(): if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer, act_params) act_out = _jax_act_lu(x, activation_type, act_params=act_params)
assert (
act_out.data.dtype == x.dtype
), f"JAX activation output dtype {act_out.data.dtype} must match input dtype {x.dtype}"
if quantizer is None:
return act_out
return quantize(
act_out,
quantizer=quantizer,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
# TE/common does not support colwise-only quantization yet # TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_layout.is_colwise_only: if quantizer is not None and quantizer.q_layout.is_colwise_only:
...@@ -1330,11 +1342,12 @@ def act_lu( ...@@ -1330,11 +1342,12 @@ def act_lu(
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True, output_amax_when_no_scaling=True,
) )
out, _ = _quantize_dbias_impl( assert (
out.data.dtype == x.dtype
), f"Activation output dtype {out.data.dtype} must match input dtype {x.dtype}"
out = quantize(
out, out,
is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
) )
...@@ -1419,7 +1432,23 @@ def quantize_dact_dbias( ...@@ -1419,7 +1432,23 @@ def quantize_dact_dbias(
if not PrimitiveClass.enabled() or ( if not PrimitiveClass.enabled() or (
quantizer is not None and quantizer.q_layout.is_colwise_only quantizer is not None and quantizer.q_layout.is_colwise_only
): ):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None:
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, act_params=act_params)
dact_out, _ = _jax_quantize_dact_dbias(
dz, x, activation_type, is_dbias=False, act_params=act_params
)
assert (
dact_out.data.dtype == x.dtype
), f"JAX dact output dtype {dact_out.data.dtype} must match input dtype {x.dtype}"
return quantize_dbias(
dact_out,
quantizer,
is_dbias=is_dbias,
flatten_axis=-2,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
if quantizer is None: if quantizer is None:
output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind( output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind(
dz, dz,
...@@ -1465,7 +1494,7 @@ def quantize_dact_dbias( ...@@ -1465,7 +1494,7 @@ def quantize_dact_dbias(
output_amax_when_no_scaling=output_amax_when_no_scaling, output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
return _quantize_dbias_impl( return _quantize_dbias_impl(
out.data, out,
quantizer, quantizer,
is_dbias=True, is_dbias=True,
dq_dtype=x.dtype, dq_dtype=x.dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment