Unverified commit 08f8a562, authored by rasmith, committed by GitHub
Browse files

[CI/Build][Kernel][BugFix][AMD] Fix per_token_group_quant_fp8 to use correct...


[CI/Build][Kernel][BugFix][AMD] Fix per_token_group_quant_fp8 to use correct fp8 min/max values and update atol/rtol in test_quantfp8_group_functionality  (#30292)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
parent b4039c08
...@@ -62,7 +62,7 @@ def test_quantfp8_group_functionality( ...@@ -62,7 +62,7 @@ def test_quantfp8_group_functionality(
assert scales_col.stride(1) == batch_size assert scales_col.stride(1) == batch_size
# Test column-major scales consistency # Test column-major scales consistency
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8)
# 3. Test CUDA implementation (only for divisible dimensions) # 3. Test CUDA implementation (only for divisible dimensions)
if is_divisible: if is_divisible:
...@@ -71,7 +71,7 @@ def test_quantfp8_group_functionality( ...@@ -71,7 +71,7 @@ def test_quantfp8_group_functionality(
assert scales_cuda.shape == (batch_size, expected_num_groups) assert scales_cuda.shape == (batch_size, expected_num_groups)
# Verify CUDA/native consistency # Verify CUDA/native consistency
assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8)
# Quantized values should mostly match # Quantized values should mostly match
diff_count = (x_quant_cuda != x_quant_native).sum().item() diff_count = (x_quant_cuda != x_quant_native).sum().item()
......
...@@ -762,9 +762,12 @@ def per_token_group_quant_fp8( ...@@ -762,9 +762,12 @@ def per_token_group_quant_fp8(
) )
assert x.stride(-1) == 1, "`x` groups must be contiguous" assert x.stride(-1) == 1, "`x` groups must be contiguous"
# Using the default value (240.0) from pytorch will cause accuracy
# issues on dynamic quantization models. Here use 224.0 for fnuz on ROCm
# platforms that use the torch.float8_e4m3fnuz dtype.
finfo = torch.finfo(dtype) finfo = torch.finfo(dtype)
fp8_min = finfo.min fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = finfo.max fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
assert out_q is None or out_q.shape == x.shape assert out_q is None or out_q.shape == x.shape
x_q = out_q x_q = out_q
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment