Unverified commit 70bb066e, authored by Azure, committed by GitHub

Fix FP4 inference corruption issue in glm4.5-air model (#9346)

parent 2c4b4b78
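Both hunks below allocate buffers whose shapes are rounded up for alignment: ((m + 128 - 1) // 128) * 128 rounds m up to a multiple of 128, and ((scale_n + 4 - 1) // 4) * 4 rounds the scale columns up to a multiple of 4. A minimal sketch of that ceiling-division idiom (round_up is a hypothetical helper name, not one from this commit):

    def round_up(x: int, multiple: int) -> int:
        # Ceiling division, then scale back up to the nearest multiple.
        return ((x + multiple - 1) // multiple) * multiple

    assert round_up(129, 128) == 256   # one extra row forces a whole new 128-row tile
    assert round_up(128, 128) == 128   # already aligned: no padding added
    assert round_up(6, 4) == 8         # scale_n = 6 -> rounded_n = 8, a 2-column pad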
@@ -205,6 +205,12 @@ def scaled_fp4_quant(
     rounded_m = ((m + 128 - 1) // 128) * 128
     scale_n = n // block_size
     rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty(
-        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
-    )
+    # padded part should be zeroed out
+    if rounded_n > scale_n:
+        output_scale = torch.zeros(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
+    else:
+        output_scale = torch.empty(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
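Why the zeroing matters (a minimal sketch of the failure mode, not the actual sglang kernel path): torch.empty returns uninitialized memory, so when rounded_n > scale_n the padded scale columns that the quantization kernel never writes contain leftover garbage, and anything downstream that reads the full padded buffer treats those garbage values as real scales — presumably the corruption the commit message describes.

    import torch

    scale_n, rounded_n = 6, 8            # rounded_n > scale_n: 2 padded columns
    scales = torch.empty(4, rounded_n)   # columns scale_n..rounded_n-1 start as garbage
    scales[:, :scale_n] = 1.0            # the kernel writes only the real columns
    # scales[:, scale_n:] is undefined here -> reading it yields arbitrary values

    scales = torch.zeros(4, rounded_n)   # zero-filled allocation instead
    scales[:, :scale_n] = 1.0            # the padded tail is now a well-defined 0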
@@ -338,6 +344,15 @@ def scaled_fp4_experts_quant(
     output = torch.empty(
         m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
     )
-    output_scales = torch.empty(
-        MAX_TOKENS_PER_EXPERT * topk,
-        padded_k,
+    # padded part should be zeroed out
+    if padded_k > scales_k:
+        output_scales = torch.zeros(
+            MAX_TOKENS_PER_EXPERT * topk,
+            padded_k,
+            dtype=torch.int32,
+            device=input_tensor.device,
+        )
+    else:
+        output_scales = torch.empty(
+            MAX_TOKENS_PER_EXPERT * topk,
+            padded_k,
...
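Note the shared design choice in both hunks: the zero-fill is applied only when padding actually exists (rounded_n > scale_n in scaled_fp4_quant, padded_k > scales_k in scaled_fp4_experts_quant); when the sizes are already aligned, the code keeps the cheaper uninitialized torch.empty allocation and skips the extra memset.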