Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
1fe691a4
"vscode:/vscode.git/clone" did not exist on "220a24cd4c2e6ec2825316e1badba4ddf8877456"
Unverified
Commit
1fe691a4
authored
Aug 02, 2025
by
YanbingJiang
Committed by
GitHub
Aug 01, 2025
Browse files
Fix FP8 block quantization when N or K is not multiples of 128 (#8648)
parent
e2521926
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
18 deletions
+39
-18
sgl-kernel/csrc/cpu/moe.cpp
sgl-kernel/csrc/cpu/moe.cpp
+10
-10
test/srt/cpu/test_moe.py
test/srt/cpu/test_moe.py
+10
-4
test/srt/cpu/utils.py
test/srt/cpu/utils.py
+19
-4
No files found.
sgl-kernel/csrc/cpu/moe.cpp
View file @
1fe691a4
...
...
@@ -955,16 +955,16 @@ static inline void check_moe_scales(
   }
 }
-#define CHECK_MOE_SCALES_FP8(DIM0, DIM1)                       \
-  auto w1s = w1_scale.value();                                 \
-  auto w2s = w2_scale.value();                                 \
-  auto block_size_val = block_size.value();                    \
-  int64_t block_size_N = block_size_val[0];                    \
-  int64_t block_size_K = block_size_val[1];                    \
-  TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N);         \
-  TORCH_CHECK(w1s.size(DIM1) == K / block_size_K);             \
-  TORCH_CHECK(w2s.size(DIM0) == K / block_size_N);             \
-  TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
+#define CHECK_MOE_SCALES_FP8(DIM0, DIM1)                       \
+  auto w1s = w1_scale.value();                                 \
+  auto w2s = w2_scale.value();                                 \
+  auto block_size_val = block_size.value();                    \
+  int64_t block_size_N = block_size_val[0];                    \
+  int64_t block_size_K = block_size_val[1];                    \
+  TORCH_CHECK(w1s.size(DIM0) == div_up(2 * N, block_size_N));  \
+  TORCH_CHECK(w1s.size(DIM1) == div_up(K, block_size_K));      \
+  TORCH_CHECK(w2s.size(DIM0) == div_up(K, block_size_N));      \
+  TORCH_CHECK(w2s.size(DIM1) == div_up(N, block_size_K))
 // hidden_states: [M, K]
 // w1: [E, 2N, K]
...
...
test/srt/cpu/test_moe.py
View file @
1fe691a4
...
...
@@ -75,8 +75,8 @@ class TestFusedExperts(CustomTestCase):
     topk_int8 = [3]
     M_fp8 = [2, 121]
-    N_fp8 = [512]
-    K_fp8 = [256]
+    N_fp8 = [352, 512]
+    K_fp8 = [256, 320]
     E_fp8 = [8]
     topk_fp8 = [4]
...
...
@@ -201,8 +201,14 @@ class TestFusedExperts(CustomTestCase):
         w2_fp32 = torch.randn(E, K, N)
         w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-        w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
-        w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
+        w1s = (
+            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
+            * factor_for_scale
+        )
+        w2s = (
+            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
+            * factor_for_scale
+        )
         w1_scaled = scaled_weight(w1, w1s)
         w2_scaled = scaled_weight(w2, w2s)
...
...
test/srt/cpu/utils.py
View file @
1fe691a4
...
...
@@ -136,18 +136,33 @@ def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor):
 def scaled_weight(weight, scales):
     E, N, K = weight.shape
+    pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
+    pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
+    if pad_N > 0 or pad_K > 0:
+        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
     weight_block = (
-        weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
+        weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
         .permute(0, 1, 3, 2, 4)
         .float()
         .contiguous()
     )
-    return (
-        (weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
+    weight_scaled = (
+        (weight_block * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1))
         .permute(0, 1, 3, 2, 4)
         .contiguous()
-        .view(E, N, K)
     )
+    if pad_N > 0 or pad_K > 0:
+        weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
+        weight_scaled = weight_scaled[..., :N, :K].contiguous()
+    else:
+        weight_scaled = weight_scaled.view(E, N, K)
+    return weight_scaled


 def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment