Unverified Commit 302b2c1e authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[CI/Build][AMD] Fix ref_dynamic_per_token_quant reference implementation on ROCm. (#30291)


Signed-off-by: default avatarRandall Smith <ransmith@amd.com>
Co-authored-by: default avatarRandall Smith <ransmith@amd.com>
parent 8f8fda26
...@@ -30,16 +30,11 @@ def ref_dynamic_per_token_quant( ...@@ -30,16 +30,11 @@ def ref_dynamic_per_token_quant(
if quant_dtype == torch.int8 if quant_dtype == torch.int8
else torch.finfo(quant_dtype) else torch.finfo(quant_dtype)
) )
qtype_traits_max = ( use_fp8fnuz = (
ROCM_FP8FNUZ_MAX current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else qtype_traits.max
)
qtype_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else qtype_traits.min
) )
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max) qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0) s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0) s_512 = as_float32_tensor(512.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment