"vscode:/vscode.git/clone" did not exist on "2d0afcc9dc925928ee8764c826a3661e487f9f82"
Unverified commit 3c268564, authored by rasmith, committed by GitHub
Browse files

[CI][AMD][Quantization][BugFix] Fix fp8 max in quant_utils.py and update...


[CI][AMD][Quantization][BugFix] Fix fp8 max in quant_utils.py and update test_fp8_quant.py::test_static_fp8_quant_group_2d to use correct fp8 dtype and adjust atol/rtol (#32201)
Signed-off-by: Randall Smith <ransmith@amd.com>
parent 773d7073
...@@ -178,12 +178,12 @@ def test_static_fp8_quant_group_2d( ...@@ -178,12 +178,12 @@ def test_static_fp8_quant_group_2d(
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale = scaled_quantize( ref_out, scale = scaled_quantize(
x, group_shape, FP8_DTYPE, compute_dtype=torch.float32 x, group_shape, current_platform.fp8_dtype(), compute_dtype=torch.float32
) )
ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape) ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape)
torch.testing.assert_close(scale, ops_scale) torch.testing.assert_close(scale, ops_scale)
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0) torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=1.2e-1, atol=1e-5)
opcheck_fp8_quant(ops_out, x, scale=scale) opcheck_fp8_quant(ops_out, x, scale=scale)
......
...@@ -221,7 +221,8 @@ def scaled_quantize( ...@@ -221,7 +221,8 @@ def scaled_quantize(
# Compute scales # Compute scales
min_val, max_val = x_blkd_permd.aminmax(dim=-1) min_val, max_val = x_blkd_permd.aminmax(dim=-1)
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax _, fp8_max = get_fp8_min_max()
scale = fp8_max / amax
# Apply scale and convert form: # Apply scale and convert form:
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment