Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
26118a13
"src/vscode:/vscode.git/clone" did not exist on "d5acb4110a5536f5c0ace4a0c158f0e0c71c0a50"
Unverified
Commit
26118a13
authored
Jul 12, 2025
by
Qi Yuhang
Committed by
GitHub
Jul 11, 2025
Browse files
[fix]Update unitest for fp8_blockwise_scaled_grouped_mm kernel (#7932)
parent
475a249b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
86 additions
and
25 deletions
+86
-25
sgl-kernel/tests/test_fp8_blockwise_moe.py
sgl-kernel/tests/test_fp8_blockwise_moe.py
+86
-25
No files found.
sgl-kernel/tests/test_fp8_blockwise_moe.py
View file @
26118a13
import
random
from
typing
import
Tuple
import
pytest
import
torch
...
...
@@ -20,6 +21,44 @@ def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
)
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def
calc_diff
(
x
,
y
):
x
,
y
=
x
.
double
(),
y
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
1
-
sim
def
ceil_div
(
x
:
int
,
y
:
int
)
->
int
:
return
(
x
+
y
-
1
)
//
y
def
per_token_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
pad_size
=
(
128
-
(
n
%
128
))
%
128
x
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
pad_size
),
value
=
0
)
if
pad_size
>
0
else
x
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
fp8_data
=
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
)
return
fp8_data
.
view
(
m
,
n
+
pad_size
)[:,
:
n
],
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
ceil_div
(
m
,
128
)
*
128
,
ceil_div
(
n
,
128
)
*
128
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
)
)
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
...
...
@@ -55,7 +94,7 @@ def is_sm100_supported(device=None) -> bool:
def
is_sm90_supported
(
device
=
None
)
->
bool
:
return
(
torch
.
cuda
.
get_device_capability
(
device
)[
0
]
==
9
)
and
(
torch
.
version
.
cuda
>=
"12.
8
"
torch
.
version
.
cuda
>=
"12.
3
"
)
...
...
@@ -66,14 +105,12 @@ def is_sm90_supported(device=None) -> bool:
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
8
,
16
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
half
,
torch
.
bfloat16
])
def
test_fp8_blockwise_scaled_grouped_mm
(
num_experts
,
out_dtype
):
cc
=
torch
.
cuda
.
get_device_capability
(
None
)[
0
]
device
=
"cuda"
alignment
=
16
n_g
=
alignment
*
random
.
randint
(
1
,
5
)
*
128
k_g
=
alignment
*
random
.
randint
(
1
,
5
)
*
128
scale_a_group_shape
=
(
1
,
128
)
scale_b_group_shape
=
(
128
,
128
)
expert_offsets
=
torch
.
zeros
((
num_experts
+
1
),
device
=
device
,
dtype
=
torch
.
int32
)
problem_sizes
=
torch
.
zeros
((
num_experts
,
3
),
device
=
device
,
dtype
=
torch
.
int32
)
layout_sfa
=
torch
.
zeros
((
num_experts
,
5
),
device
=
device
,
dtype
=
torch
.
int32
)
...
...
@@ -90,20 +127,21 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
expert_offsets
[
g
+
1
]
=
expert_offsets
[
g
]
+
m_g
problem_sizes
[
g
][:]
=
torch
.
tensor
([
m_g
,
n_g
,
k_g
],
device
=
device
)
a_g
=
to_fp8
(
torch
.
randn
((
m_g
,
k_g
),
device
=
device
))
b_g
=
to_fp8
(
torch
.
randn
((
n_g
,
k_g
),
device
=
device
).
t
())
a
=
torch
.
randn
((
m_g
,
k_g
),
device
=
device
,
dtype
=
out_dtype
)
# (M, K):(K, 1)
b
=
torch
.
randn
((
n_g
,
k_g
),
device
=
device
,
dtype
=
out_dtype
).
t
()
# (K, N):(1, K)
a_g
,
a_scale
=
per_token_cast_to_fp8
(
a
)
# ag -- (M, K):(K, 1), a_scale() -- (M, k):(k, 1)
b_g
,
b_scale
=
per_block_cast_to_fp8
(
b
)
# bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
a_tensors
.
append
(
a_g
)
b_tensors
.
append
(
b_g
)
a_scales_tensors
.
append
(
a_scale
)
b_scales_tensors
.
append
(
b_scale
)
scale_a_shape
=
scale_shape
(
a_g
.
shape
,
scale_a_group_shape
)
scale_b_shape
=
scale_shape
(
b_g
.
shape
,
scale_b_group_shape
)
a_scales_tensors
.
append
(
torch
.
randn
(
scale_a_shape
,
device
=
device
)
*
0.001
)
b_scales_tensors
.
append
(
torch
.
randn
(
scale_b_shape
,
device
=
device
)
*
0.001
)
baseline
=
baseline_scaled_mm
(
a_g
,
b_g
,
a_scales_tensors
[
-
1
],
b_scales_tensors
[
-
1
],
out_dtype
)
baseline
=
torch
.
mm
(
a
,
b
)
baseline_tensors
.
append
(
baseline
)
a_stack
=
torch
.
empty
(
...
...
@@ -114,21 +152,41 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
)
for
g
in
range
(
num_experts
):
a_stack
[
expert_offsets
[
g
]
:
expert_offsets
[
g
+
1
]]
=
a_tensors
[
g
]
b_stack
[
g
]
=
b_tensors
[
g
].
t
()
b_stack
=
b_stack
.
transpose
(
1
,
2
)
# Matrix A is Row-Major.
a_stack
[
expert_offsets
[
g
]
:
expert_offsets
[
g
+
1
]]
=
a_tensors
[
g
]
# a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
b_stack
[
g
]
=
b_tensors
[
g
].
t
()
# b_stack[g] -- (N, K):(K, 1)
b_stack
=
b_stack
.
transpose
(
1
,
2
)
# Transpose Matrix B to Column-Major.
a_scale_stack
=
torch
.
empty
(
(
expert_offsets
[
-
1
]
,
k_g
//
128
),
device
=
device
,
dtype
=
torch
.
float32
(
expert_offsets
[
-
1
]
*
(
k_g
//
128
)
)
,
device
=
device
,
dtype
=
torch
.
float32
)
b_scale_stack
=
torch
.
empty
(
(
num_experts
,
n
_g
//
128
,
k
_g
//
128
),
device
=
device
,
dtype
=
torch
.
float32
(
num_experts
,
k
_g
//
128
,
n
_g
//
128
),
device
=
device
,
dtype
=
torch
.
float32
)
for
g
in
range
(
num_experts
):
a_scale_stack
[
expert_offsets
[
g
]
:
expert_offsets
[
g
+
1
]]
=
a_scales_tensors
[
g
]
b_scale_stack
[
g
]
=
b_scales_tensors
[
g
].
t
()
b_scale_stack
=
b_scale_stack
.
transpose
(
1
,
2
)
if
cc
==
9
:
# For SM90, we need MN-Major scale factor
# a_scales_tensors[g] -- (M, k):(k, 1)
# a_scales_tensors[g].t().contiguous() -- (k, M):(M, 1)
a_scale_stack
[
expert_offsets
[
g
]
*
(
k_g
//
128
)
:
expert_offsets
[
g
+
1
]
*
(
k_g
//
128
)
]
=
(
a_scales_tensors
[
g
].
t
().
contiguous
().
view
(
-
1
))
b_scale_stack
[
g
]
=
b_scales_tensors
[
g
]
# b_scale_stack[g] -- (k, n):(n, 1)
elif
cc
==
10
:
# For SM100, we need K-Major scale factor
# a_scales_tensors[g] -- (M, k):(k, 1)
a_scale_stack
[
expert_offsets
[
g
]
*
(
k_g
//
128
)
:
expert_offsets
[
g
+
1
]
*
(
k_g
//
128
)
]
=
a_scales_tensors
[
g
].
view
(
-
1
)
b_scale_stack
[
g
]
=
b_scales_tensors
[
g
]
# b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
a_scale_stack
=
a_scale_stack
.
view
(
expert_offsets
[
-
1
],
k_g
//
128
)
if
cc
==
10
:
b_scale_stack
=
b_scale_stack
.
transpose
(
1
,
2
).
contiguous
()
c_out
=
torch
.
empty
((
expert_offsets
[
-
1
],
n_g
),
device
=
device
,
dtype
=
out_dtype
)
a_strides
=
torch
.
full
(
...
...
@@ -168,8 +226,11 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
for
g
in
range
(
num_experts
):
baseline
=
baseline_tensors
[
g
]
actual
=
c_out
[
expert_offsets
[
g
]
:
expert_offsets
[
g
+
1
]]
torch
.
testing
.
assert_close
(
actual
,
baseline
,
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
f
"num_experts=
{
num_experts
}
, out_dtype=
{
out_dtype
}
: OK"
)
diff
=
calc_diff
(
actual
,
baseline
)
assert
diff
<
0.001
print
(
f
"cc=
{
cc
}
0 num_experts=
{
num_experts
}
, out_dtype=
{
out_dtype
}
, diff=
{
diff
:.
5
f
}
: OK"
)
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment