Unverified commit 66c98874, authored by Kevin McKay, committed by GitHub
Browse files

[Bugfix][Hardware][AMD] Fix FP8 dtype in silu_mul quantization (#31179)


Signed-off-by: c0de128 <kevin.mckay@outlook.com>
parent 1ff67df1
...@@ -625,8 +625,9 @@ def silu_mul_per_token_group_quant_fp8_colmajor( ...@@ -625,8 +625,9 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
M, N = input.size() M, N = input.size()
N_2 = N // 2 N_2 = N // 2
fp8_dtype = current_platform.fp8_dtype()
if output is None: if output is None:
output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device) output = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device)
output_scales = torch.empty( output_scales = torch.empty(
((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
...@@ -637,9 +638,12 @@ def silu_mul_per_token_group_quant_fp8_colmajor( ...@@ -637,9 +638,12 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
assert M % BLOCK_M == 0 assert M % BLOCK_M == 0
assert N_2 % BLOCK_N == 0 assert N_2 % BLOCK_N == 0
finfo = torch.finfo(torch.float8_e4m3fn) # Using the default value (240.0) from pytorch will cause accuracy
fp8_min = finfo.min # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
fp8_max = finfo.max # platforms that use the torch.float8_e4m3fnuz dtype.
finfo = torch.finfo(fp8_dtype)
fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
# Force even division so we can avoid edgecases within the kernel. # Force even division so we can avoid edgecases within the kernel.
assert M % BLOCK_M == 0 assert M % BLOCK_M == 0
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment