Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ddf8981d
Unverified
Commit
ddf8981d
authored
Mar 30, 2025
by
yinfan98
Committed by
GitHub
Mar 29, 2025
Browse files
Delete test_deep_gemm.py (#4891)
parent
400ad660
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
263 deletions
+0
-263
sgl-kernel/tests/test_deep_gemm.py
sgl-kernel/tests/test_deep_gemm.py
+0
-263
No files found.
sgl-kernel/tests/test_deep_gemm.py
deleted
100644 → 0
View file @
400ad660
import
os
import
random
import
unittest
from
typing
import
Any
,
Tuple
import
deep_gemm
import
torch
from
deep_gemm
import
calc_diff
,
ceil_div
,
get_col_major_tma_aligned_tensor
,
jit
"""
fork deepgemm/tests/test_core.py
"""
def
per_token_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
and
x
.
size
(
1
)
%
128
==
0
m
,
n
=
x
.
shape
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
ceil_div
(
m
,
128
)
*
128
,
ceil_div
(
n
,
128
)
*
128
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
)
)
def
construct
(
m
:
int
,
k
:
int
,
n
:
int
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
]:
x
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
randn
((
n
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
out
=
torch
.
empty
((
m
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
ref_out
=
x
@
y
.
t
()
x_fp8
,
y_fp8
=
per_token_cast_to_fp8
(
x
),
per_block_cast_to_fp8
(
y
)
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8
=
(
x_fp8
[
0
],
get_col_major_tma_aligned_tensor
(
x_fp8
[
1
]))
return
x_fp8
,
y_fp8
,
out
,
ref_out
def
construct_grouped
(
num_groups
:
int
,
m
:
int
,
k
:
int
,
n
:
int
,
is_masked
:
bool
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
]:
x
=
torch
.
randn
((
num_groups
,
m
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
y
=
torch
.
randn
((
num_groups
,
n
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
out
=
torch
.
empty
((
num_groups
,
m
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
ref_out
=
torch
.
einsum
(
"gmk,gnk->gmn"
,
x
,
y
)
assert
m
%
4
==
0
,
f
"TMA alignment error:
{
m
}
"
x_fp8
=
(
torch
.
empty_like
(
x
,
dtype
=
torch
.
float8_e4m3fn
),
torch
.
empty
((
num_groups
,
m
,
k
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float
),
)
y_fp8
=
(
torch
.
empty_like
(
y
,
dtype
=
torch
.
float8_e4m3fn
),
torch
.
empty
(
(
num_groups
,
(
n
+
127
)
//
128
,
k
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float
),
)
for
i
in
range
(
num_groups
):
x_fp8
[
0
][
i
],
x_fp8
[
1
][
i
]
=
per_token_cast_to_fp8
(
x
[
i
])
y_fp8
[
0
][
i
],
y_fp8
[
1
][
i
]
=
per_block_cast_to_fp8
(
y
[
i
])
# For non-masked input, we must merge the group and M dims
if
not
is_masked
:
x_fp8
=
(
x_fp8
[
0
].
view
(
-
1
,
k
),
per_token_cast_to_fp8
(
x
.
view
(
-
1
,
k
))[
1
])
out
,
ref_out
=
out
.
view
(
-
1
,
n
),
ref_out
.
view
(
-
1
,
n
)
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8
=
(
x_fp8
[
0
],
get_col_major_tma_aligned_tensor
(
x_fp8
[
1
]))
return
x_fp8
,
y_fp8
,
out
,
ref_out
class
TestDeepGemmCore
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
torch
.
backends
.
cudnn
.
allow_tf32
=
True
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
print
(
"Library path:"
)
print
(
f
" >
{
deep_gemm
.
__path__
}
\n
"
)
def
test_gemm
(
self
):
print
(
"Testing GEMM:"
)
for
m
in
(
64
,
128
,
4096
):
for
k
,
n
in
[
(
7168
,
2112
),
(
1536
,
24576
),
(
512
,
32768
),
(
16384
,
7168
),
(
7168
,
4096
),
(
2048
,
7168
),
]:
x_fp8
,
y_fp8
,
out
,
ref_out
=
construct
(
m
,
k
,
n
)
deep_gemm
.
gemm_fp8_fp8_bf16_nt
(
x_fp8
,
y_fp8
,
out
)
diff
=
calc_diff
(
out
,
ref_out
)
self
.
assertTrue
(
diff
<
0.001
,
f
"
{
m
=
}
,
{
k
=
}
,
{
n
=
}
,
{
diff
:.
5
f
}
"
)
def
test_m_grouped_gemm_contiguous
(
self
):
print
(
"Testing grouped contiguous GEMM:"
)
for
num_groups
,
m
,
k
,
n
in
(
(
4
,
8192
,
7168
,
4096
),
(
4
,
8192
,
2048
,
7168
),
(
8
,
4096
,
7168
,
4096
),
(
8
,
4096
,
2048
,
7168
),
):
# TODO: make a stronger test
x_fp8
,
y_fp8
,
out
,
ref_out
=
construct_grouped
(
num_groups
,
m
,
k
,
n
,
is_masked
=
False
)
m_indices
=
torch
.
arange
(
0
,
num_groups
,
device
=
"cuda"
,
dtype
=
torch
.
int
)
m_indices
=
(
m_indices
.
unsqueeze
(
-
1
).
expand
(
num_groups
,
m
).
contiguous
().
view
(
-
1
)
)
deep_gemm
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
x_fp8
,
y_fp8
,
out
,
m_indices
)
diff
=
calc_diff
(
out
,
ref_out
)
self
.
assertTrue
(
diff
<
0.001
,
f
"m=
{
m
*
num_groups
}
,
{
k
=
}
,
{
n
=
}
,
{
diff
:.
5
f
}
"
)
def
test_m_grouped_gemm_masked
(
self
):
print
(
"Testing grouped masked GEMM:"
)
for
num_groups
,
m
in
((
1
,
1024
),
(
2
,
512
),
(
4
,
256
)):
for
k
,
n
in
(
(
7168
,
4096
),
(
2048
,
7168
),
):
# Test correctness
masked_m_candidates
=
list
(
filter
(
lambda
candidate
:
candidate
<=
m
,
(
64
,
128
,
192
,
256
,
320
,
384
)
)
)
for
i
in
range
(
10
):
x_fp8
,
y_fp8
,
out
,
ref_out
=
construct_grouped
(
num_groups
,
m
,
k
,
n
,
is_masked
=
True
)
masked_m
=
torch
.
empty
(
(
num_groups
,),
device
=
"cuda"
,
dtype
=
torch
.
int
)
for
j
in
range
(
num_groups
):
masked_m
[
j
]
=
random
.
choice
(
masked_m_candidates
)
expected_m
=
min
(
int
(
masked_m
.
float
().
mean
())
+
1
,
m
)
deep_gemm
.
m_grouped_gemm_fp8_fp8_bf16_nt_masked
(
x_fp8
,
y_fp8
,
out
,
masked_m
,
expected_m
)
for
j
in
range
(
num_groups
):
diff
=
calc_diff
(
out
[
j
,
:
masked_m
[
j
].
item
()],
ref_out
[
j
,
:
masked_m
[
j
].
item
()],
)
self
.
assertTrue
(
diff
<
0.001
,
f
"
{
m
=
}
,
{
k
=
}
,
{
n
=
}
,
{
j
=
}
, masked_m=
{
masked_m
[
j
]
}
,
{
num_groups
=
}
,
{
diff
:.
5
f
}
"
,
)
"""
fork deepgemm/tests/test_jit.py
"""
class
Capture
:
def
__init__
(
self
)
->
None
:
self
.
read_fd
=
None
self
.
write_fd
=
None
self
.
saved_stdout
=
None
self
.
captured
=
None
def
__enter__
(
self
)
->
Any
:
self
.
read_fd
,
self
.
write_fd
=
os
.
pipe
()
self
.
saved_stdout
=
os
.
dup
(
1
)
os
.
dup2
(
self
.
write_fd
,
1
)
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
)
->
None
:
os
.
dup2
(
self
.
saved_stdout
,
1
)
os
.
close
(
self
.
write_fd
)
with
os
.
fdopen
(
self
.
read_fd
,
"r"
)
as
f
:
self
.
captured
=
f
.
read
()
def
capture
(
self
)
->
str
:
return
self
.
captured
class
TestDeepGemmJIT
(
unittest
.
TestCase
):
def
test_jit
(
self
):
# Runtime
print
(
f
"NVCC compiler:
{
jit
.
get_nvcc_compiler
()
}
\n
"
)
# Templates
print
(
"Generated code:"
)
args
=
(
(
"lhs"
,
torch
.
float8_e4m3fn
),
(
"rhs"
,
torch
.
float8_e4m3fn
),
(
"scale"
,
torch
.
float
),
(
"out"
,
torch
.
bfloat16
),
(
"enable_double_streams"
,
bool
),
(
"stream"
,
torch
.
cuda
.
Stream
),
)
body
=
"
\n
"
body
+=
"std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;
\n
"
body
+=
"std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;
\n
"
body
+=
"std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;
\n
"
body
+=
"std::cout << reinterpret_cast<uint64_t>(out) << std::endl;
\n
"
body
+=
"std::cout << enable_double_streams << std::endl;
\n
"
body
+=
"std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;
\n
"
code
=
jit
.
generate
((),
args
,
body
)
print
(
code
)
# Build
print
(
"Building ..."
)
func
=
jit
.
build
(
"test_func"
,
args
,
code
)
# Test correctness
print
(
"Running ..."
)
fp8_tensor
=
torch
.
empty
((
1
,),
dtype
=
torch
.
float8_e4m3fn
,
device
=
"cuda"
)
fp32_tensor
=
torch
.
empty
((
1
,),
dtype
=
torch
.
float
,
device
=
"cuda"
)
bf16_tensor
=
torch
.
empty
((
1
,),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
with
Capture
()
as
capture
:
self
.
assertTrue
(
func
(
fp8_tensor
,
fp8_tensor
,
fp32_tensor
,
bf16_tensor
,
True
,
torch
.
cuda
.
current_stream
(),
)
==
0
)
output
=
capture
.
capture
()
ref_output
=
f
"
{
fp8_tensor
.
data_ptr
()
}
\n
{
fp8_tensor
.
data_ptr
()
}
\n
{
fp32_tensor
.
data_ptr
()
}
\n
{
bf16_tensor
.
data_ptr
()
}
\n
1
\n
{
torch
.
cuda
.
current_stream
().
cuda_stream
}
\n
"
self
.
assertTrue
(
output
==
ref_output
,
f
"
{
output
=
}
,
{
ref_output
=
}
"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment