Unverified commit 70bb066e authored by Azure, committed by GitHub

Fix FP4 inference corruption issue in glm4.5-air model (#9346)

The padded region of the FP4 output-scale tensors was allocated with torch.empty and never written by the quantization kernels, so uninitialized values in the padding could be consumed downstream as if they were real scale factors. Zero-fill the tensors whenever the rounded layout actually adds padding.

parent 2c4b4b78
@@ -205,9 +205,15 @@ def scaled_fp4_quant(
     rounded_m = ((m + 128 - 1) // 128) * 128
     scale_n = n // block_size
     rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty(
-        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
-    )
+    # padded part should be zeroed out
+    if rounded_n > scale_n:
+        output_scale = torch.zeros(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
+    else:
+        output_scale = torch.empty(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
     torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
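
For context, a minimal sketch of the allocation rule this hunk introduces; this is not the sgl_kernel code, and the block_size default here is illustrative. Zero-fill only when the rounded layout adds padding; otherwise keep the cheaper uninitialized allocation, since the kernel overwrites every valid element anyway.

import torch

def alloc_output_scale(m: int, n: int, block_size: int = 16) -> torch.Tensor:
    # Mirrors the fix above: rows round up to a multiple of 128, scale
    # columns round up to a multiple of 4 and are packed 4-per-int32.
    rounded_m = ((m + 128 - 1) // 128) * 128
    scale_n = n // block_size
    rounded_n = ((scale_n + 4 - 1) // 4) * 4
    shape = (rounded_m, rounded_n // 4)
    if rounded_n > scale_n:
        # Padding exists: the kernel never writes it, so it must be zeroed.
        return torch.zeros(shape, dtype=torch.int32)
    # No padding: every element is overwritten, so empty() is safe and cheaper.
    return torch.empty(shape, dtype=torch.int32)
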
@@ -338,12 +344,21 @@ def scaled_fp4_experts_quant(
     output = torch.empty(
         m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
     )
-    output_scales = torch.empty(
-        MAX_TOKENS_PER_EXPERT * topk,
-        padded_k,
-        dtype=torch.int32,
-        device=input_tensor.device,
-    )
+    # padded part should be zeroed out
+    if padded_k > scales_k:
+        output_scales = torch.zeros(
+            MAX_TOKENS_PER_EXPERT * topk,
+            padded_k,
+            dtype=torch.int32,
+            device=input_tensor.device,
+        )
+    else:
+        output_scales = torch.empty(
+            MAX_TOKENS_PER_EXPERT * topk,
+            padded_k,
+            dtype=torch.int32,
+            device=input_tensor.device,
+        )
     torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
         output,
         output_scales,
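
The failure mode both hunks address can be shown in miniature. A hedged sketch with made-up shapes (rows, scales_k, padded_k are illustrative, not the real kernel's values):

import torch

rows, scales_k, padded_k = 4, 3, 4

# Before the fix: empty() leaves the padded tail uninitialized. A kernel that
# writes only the first scales_k columns may leave arbitrary bytes in the
# rest, which a downstream FP4 GEMM then consumes as real scale data.
scales = torch.empty(rows, padded_k, dtype=torch.int32)
scales[:, :scales_k] = 1          # simulate the kernel writing the valid region
print(scales[:, scales_k:])       # may contain arbitrary leftover values

# After the fix: zeros() makes the padded tail deterministic.
scales = torch.zeros(rows, padded_k, dtype=torch.int32)
scales[:, :scales_k] = 1
assert (scales[:, scales_k:] == 0).all()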