Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
70bb066e
Unverified
Commit
70bb066e
authored
Aug 21, 2025
by
Azure
Committed by
GitHub
Aug 20, 2025
Browse files
Fix FP4 inference corruption issue in glm4.5-air model (#9346)
parent
2c4b4b78
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
9 deletions
+24
-9
sgl-kernel/python/sgl_kernel/gemm.py
sgl-kernel/python/sgl_kernel/gemm.py
+24
-9
No files found.
sgl-kernel/python/sgl_kernel/gemm.py
View file @
70bb066e
...
@@ -205,9 +205,15 @@ def scaled_fp4_quant(
...
@@ -205,9 +205,15 @@ def scaled_fp4_quant(
rounded_m
=
((
m
+
128
-
1
)
//
128
)
*
128
rounded_m
=
((
m
+
128
-
1
)
//
128
)
*
128
scale_n
=
n
//
block_size
scale_n
=
n
//
block_size
rounded_n
=
((
scale_n
+
4
-
1
)
//
4
)
*
4
rounded_n
=
((
scale_n
+
4
-
1
)
//
4
)
*
4
output_scale
=
torch
.
empty
(
# padded part should be zeroed out
(
rounded_m
,
rounded_n
//
4
),
device
=
device
,
dtype
=
torch
.
int32
if
rounded_n
>
scale_n
:
)
output_scale
=
torch
.
zeros
(
(
rounded_m
,
rounded_n
//
4
),
device
=
device
,
dtype
=
torch
.
int32
)
else
:
output_scale
=
torch
.
empty
(
(
rounded_m
,
rounded_n
//
4
),
device
=
device
,
dtype
=
torch
.
int32
)
torch
.
ops
.
sgl_kernel
.
scaled_fp4_quant
.
default
(
torch
.
ops
.
sgl_kernel
.
scaled_fp4_quant
.
default
(
output
,
input
,
output_scale
,
input_global_scale
output
,
input
,
output_scale
,
input_global_scale
...
@@ -338,12 +344,21 @@ def scaled_fp4_experts_quant(
...
@@ -338,12 +344,21 @@ def scaled_fp4_experts_quant(
output
=
torch
.
empty
(
output
=
torch
.
empty
(
m_numtopk
,
k
//
2
,
device
=
input_tensor
.
device
,
dtype
=
torch
.
uint8
m_numtopk
,
k
//
2
,
device
=
input_tensor
.
device
,
dtype
=
torch
.
uint8
)
)
output_scales
=
torch
.
empty
(
# padded part should be zeroed out
MAX_TOKENS_PER_EXPERT
*
topk
,
if
padded_k
>
scales_k
:
padded_k
,
output_scales
=
torch
.
zeros
(
dtype
=
torch
.
int32
,
MAX_TOKENS_PER_EXPERT
*
topk
,
device
=
input_tensor
.
device
,
padded_k
,
)
dtype
=
torch
.
int32
,
device
=
input_tensor
.
device
,
)
else
:
output_scales
=
torch
.
empty
(
MAX_TOKENS_PER_EXPERT
*
topk
,
padded_k
,
dtype
=
torch
.
int32
,
device
=
input_tensor
.
device
,
)
torch
.
ops
.
sgl_kernel
.
scaled_fp4_experts_quant
.
default
(
torch
.
ops
.
sgl_kernel
.
scaled_fp4_experts_quant
.
default
(
output
,
output
,
output_scales
,
output_scales
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment