"...git@developer.sourcefind.cn:jerrrrry/infinilm.git" did not exist on "81fe2ba35fb8cde53088eae7fab5abe6fba711aa"
Unverified commit f64d1459, authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Fix 1x quantize kernel availability check on hopper (#1845)



* Fix 1x quantize kernel availability check on hopper

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 12af02f2
...@@ -183,6 +183,16 @@ def get_xla_flag(flag: str, default=None, cast=str): ...@@ -183,6 +183,16 @@ def get_xla_flag(flag: str, default=None, cast=str):
return default return default
def get_min_device_compute_capability():
    """
    Returns the minimum compute capability of all local devices.
    """
    # Query the compute capability of each locally visible device by its
    # local GPU index and take the smallest, so callers can gate kernels
    # on the least-capable device in a heterogeneous node.
    num_local_devices = len(jax.local_devices())
    capabilities = (
        transformer_engine_jax.get_device_compute_capability(device_index)
        for device_index in range(num_local_devices)
    )
    return min(capabilities)
def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None): def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
""" """
Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
......
...@@ -23,6 +23,7 @@ from .misc import ( ...@@ -23,6 +23,7 @@ from .misc import (
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
multidim_transpose, multidim_transpose,
should_apply_1x_fused_dbias_war_for_arch_l_100, should_apply_1x_fused_dbias_war_for_arch_l_100,
get_min_device_compute_capability,
NamedSharding, NamedSharding,
) )
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
...@@ -629,8 +630,13 @@ def _quantize_dbias_impl( ...@@ -629,8 +630,13 @@ def _quantize_dbias_impl(
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale scale = quantizer.scale
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
# It is faster to use 1x quantization for tensor scaling # It is faster to use 1x quantization for tensor scaling
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x()
and is_1x_kernel_supported
)
q_layout = quantizer.q_layout q_layout = quantizer.q_layout
if force_1x_quantization: if force_1x_quantization:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment