Unverified commit 70bb066e, authored by Azure and committed by GitHub

Fix FP4 inference corruption issue in glm4.5-air model (#9346)

parent 2c4b4b78
@@ -205,9 +205,15 @@ def scaled_fp4_quant(
     rounded_m = ((m + 128 - 1) // 128) * 128
     scale_n = n // block_size
     rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty(
-        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
-    )
+    # padded part should be zeroed out
+    if rounded_n > scale_n:
+        output_scale = torch.zeros(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
+    else:
+        output_scale = torch.empty(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
     torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
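The failure mode this hunk addresses: `torch.empty` returns uninitialized memory, and the quant kernel evidently writes only the valid `m x scale_n` region, so stale bytes in the padded scale columns were read by the downstream FP4 GEMM as garbage scales. A minimal sketch of the new allocation logic, pulled out as a standalone helper (the name `alloc_fp4_scale_buffer` and the `block_size=16` default are assumptions for illustration, not part of the patch):

```python
import torch

def alloc_fp4_scale_buffer(m: int, n: int, block_size: int = 16,
                           device: str = "cuda") -> torch.Tensor:
    """Allocate the swizzled FP4 block-scale tensor (hypothetical helper
    mirroring the allocation logic in scaled_fp4_quant above)."""
    # Rows are padded to a multiple of 128, scale columns to a multiple
    # of 4, and four 8-bit scales are packed into each int32.
    rounded_m = ((m + 128 - 1) // 128) * 128
    scale_n = n // block_size
    rounded_n = ((scale_n + 4 - 1) // 4) * 4
    if rounded_n > scale_n:
        # Column padding exists: the kernel never writes it, so it must
        # be zero-filled or stale memory leaks into the FP4 GEMM.
        return torch.zeros(
            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
        )
    # No column padding: keep the cheaper uninitialized allocation.
    return torch.empty(
        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
    )
```

Zero-filling only when `rounded_n > scale_n` keeps the fast `torch.empty` path for shapes whose scale width is already a multiple of 4.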
@@ -338,12 +344,21 @@ def scaled_fp4_experts_quant(
     output = torch.empty(
         m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
     )
-    output_scales = torch.empty(
-        MAX_TOKENS_PER_EXPERT * topk,
-        padded_k,
-        dtype=torch.int32,
-        device=input_tensor.device,
-    )
+    # padded part should be zeroed out
+    if padded_k > scales_k:
+        output_scales = torch.zeros(
+            MAX_TOKENS_PER_EXPERT * topk,
+            padded_k,
+            dtype=torch.int32,
+            device=input_tensor.device,
+        )
+    else:
+        output_scales = torch.empty(
+            MAX_TOKENS_PER_EXPERT * topk,
+            padded_k,
+            dtype=torch.int32,
+            device=input_tensor.device,
+        )
     torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
         output,
         output_scales,
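The experts-quant path gets the same fix. One compact alternative, sketched here under the assumption that `scales_k = k // block_size` and `padded_k` rounds it up to a multiple of 4 (that derivation sits outside the visible hunk), is to select the allocator once instead of duplicating the six-line argument list:

```python
import torch

def alloc_fp4_experts_scale_buffer(max_tokens_per_expert: int, topk: int,
                                   k: int, block_size: int = 16,
                                   device: str = "cuda") -> torch.Tensor:
    # Assumed derivation of scales_k / padded_k (not shown in the hunk):
    # one scale per block_size elements, columns rounded up to a multiple of 4.
    scales_k = k // block_size
    padded_k = ((scales_k + 4 - 1) // 4) * 4
    # Zero-fill only when padding exists; the unpadded path stays cheap.
    alloc = torch.zeros if padded_k > scales_k else torch.empty
    return alloc(
        max_tokens_per_expert * topk,
        padded_k,
        dtype=torch.int32,
        device=device,
    )
```

Either spelling is behaviorally identical to the patch; the conditional allocator just avoids repeating the allocation arguments in both branches.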