sglang, commit 9fccda31 (unverified)
Authored Mar 30, 2025 by Adarsh Shirawalmath; committed via GitHub on Mar 30, 2025

[Feature] use pytest for sgl-kernel (#4896)

Parent: 4ede6770

Showing 10 changed files with 275 additions and 302 deletions (+275 -302)
Files changed:

  .github/workflows/pr-test-sgl-kernel.yml                    +3    -2
  sgl-kernel/tests/speculative/test_eagle_utils.py            +6    -5
  sgl-kernel/tests/speculative/test_speculative_sampling.py   +16   -20
  sgl-kernel/tests/test_awq_dequant.py                        +0    -1
  sgl-kernel/tests/test_cublas_grouped_gemm.py                +32   -41
  sgl-kernel/tests/test_fp8_blockwise_gemm.py                 +33   -51
  sgl-kernel/tests/test_fp8_gemm.py                           +31   -49
  sgl-kernel/tests/test_int8_gemm.py                          +26   -35
  sgl-kernel/tests/test_per_token_quant_fp8.py                +0    -1
  sgl-kernel/tests/test_trt_allreduce.py                      +128  -97
.github/workflows/pr-test-sgl-kernel.yml

...
@@ -80,7 +80,8 @@ jobs:
       - name: Install
         run: |
-          pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1
+          bash scripts/ci_install_dependency.sh
+          pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2
           pip3 uninstall sgl-kernel -y || true
           pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
           pip3 list | grep sgl-kernel
...
@@ -89,7 +90,7 @@ jobs:
         timeout-minutes: 30
         run: |
           cd sgl-kernel
-          find tests -name "test_*.py" | xargs -n 1 python3
+          pytest tests/
       - name: Uninstall dependencies
         run: |
...
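Note: this workflow change is the core of the commit. The old step ran each file as a standalone script (`find tests -name "test_*.py" | xargs -n 1 python3`), so a file only executed whatever its `if __name__ == "__main__":` block did and there was no combined summary; `pytest tests/` collects every `test_*.py` file and `test_*` function under `tests/` in a single run. A minimal sketch of the equivalent programmatic invocation (not part of the diff; the `-q` flag is just an illustrative choice):

    # Sketch only: drive the same collection from Python instead of the CLI.
    # pytest.main() accepts the usual command-line arguments and returns the exit code.
    import sys

    import pytest

    if __name__ == "__main__":
        # Equivalent of running `pytest tests/ -q` from the sgl-kernel directory.
        sys.exit(pytest.main(["tests/", "-q"]))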
sgl-kernel/tests/speculative/test_eagle_utils.py

+import pytest
 import torch
 import torch.nn.functional as F
 from sgl_kernel import verify_tree_greedy
...
@@ -85,14 +86,14 @@ def test_verify_tree_greedy():
     print(f"{accept_index=}")
     print(f"{accept_token_num=}")
-    return predicts, accept_index, accept_token_num
-
-
-if __name__ == "__main__":
-    predicts, accept_index, accept_token_num = test_verify_tree_greedy()

     # Check the expected output.
     assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
     assert accept_index.tolist() == [
         [0, 3, 4, 5],
         [6, 10, 11, -1],
     ]
     assert accept_token_num.tolist() == [3, 2]
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
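Note: with the assertions now inside `test_verify_tree_greedy`, a relationship in the expected values above is worth spelling out (this is an observation about the data shown, not something the test asserts): each `accept_token_num` entry equals the number of non-padding slots in the corresponding `accept_index` row minus one, the extra slot plausibly being the root/verified token. A tiny self-contained check of that reading:

    # Observation only, using the expected values from the assertions above.
    accept_index = [[0, 3, 4, 5], [6, 10, 11, -1]]
    accept_token_num = [3, 2]

    for row, num in zip(accept_index, accept_token_num):
        non_padding = sum(1 for idx in row if idx != -1)  # -1 entries are padding
        assert num == non_padding - 1  # one slot is not counted as an accepted draft token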
sgl-kernel/tests/speculative/test_speculative_sampling.py

+import pytest
 import torch
 import torch.nn.functional as F
 from sgl_kernel import tree_speculative_sampling_target_only
...
@@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
     print(f"{accept_index=}")
     print(f"{accept_token_num=}")
-    return predicts, accept_index, accept_token_num
+
+    if threshold_single == 1 and threshold_acc == 1:
+        assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
+        assert accept_index.tolist() == [
+            [0, 3, 4, 5],
+            [6, 10, 11, -1],
+        ]
+        assert accept_token_num.tolist() == [3, 2]
+    elif threshold_single == 0 and threshold_acc == 0:
+        assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
+        assert accept_index.tolist() == [
+            [0, 1, 2, -1],
+            [6, 10, 11, -1],
+        ]
+        assert accept_token_num.tolist() == [2, 2]


 if __name__ == "__main__":
-    predicts, accept_index, accept_token_num = (
-        test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
-    )
-    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
-    assert accept_index.tolist() == [
-        [0, 3, 4, 5],
-        [6, 10, 11, -1],
-    ]
-    assert accept_token_num.tolist() == [3, 2]
-
-    predicts, accept_index, accept_token_num = (
-        test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
-    )
-    assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
-    assert accept_index.tolist() == [
-        [0, 1, 2, -1],
-        [6, 10, 11, -1],
-    ]
-    assert accept_token_num.tolist() == [2, 2]
+    pytest.main([__file__])
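Note: the rewritten test keeps one function whose `threshold_single`/`threshold_acc` arguments select which expected output is checked, and the `__main__` block simply defers to pytest. Shown below purely as a sketch (not part of the commit), the same two configurations could also surface as separate pytest cases by parametrizing a thin wrapper around the existing function, which already asserts its own expected outputs:

    # Hypothetical sketch, to live in the same file as
    # test_tree_speculative_sampling_target_only; not part of this commit.
    import pytest


    @pytest.mark.parametrize("threshold_single, threshold_acc", [(1, 1), (0, 0)])
    def test_thresholds_parametrized(threshold_single, threshold_acc):
        # The wrapped function performs the kernel call and the assertions itself.
        test_tree_speculative_sampling_target_only(
            threshold_single=threshold_single, threshold_acc=threshold_acc
        )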
sgl-kernel/tests/test_awq_dequant.py

...
@@ -128,5 +128,4 @@ def test_awq_dequant_compare_implementations(
 if __name__ == "__main__":
-    # Run the specific test function directly
     pytest.main([__file__])
sgl-kernel/tests/test_cublas_grouped_gemm.py

-import unittest
+import pytest
 import torch
 from sgl_kernel import cublas_grouped_gemm


 def torch_grouped_gemm(a_array, b_array, out_dtype):
-    c_array = []
-    for a, b in zip(a_array, b_array):
-        c_array.append(torch.matmul(a, b.t()).to(out_dtype))
-    return c_array
-
-
-class TestGroupedGemm(unittest.TestCase):
-    def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
-        group_count = len(Ms)
-        a_array = []
-        b_array = []
-        c_array_cublas = []
-        for i in range(group_count):
-            M, N, K = Ms[i], Ns[i], Ks[i]
-            a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
-            b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
-            c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))
-
-        c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
-        cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)
-        for i in range(group_count):
-            M, N, K = Ms[i], Ns[i], Ks[i]
-            torch.testing.assert_close(c_array_torch[i], c_array_cublas[i])
-            print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 16, 32, 256, 1024]
-        Ns = [2, 16, 128, 256, 4096]
-        Ks = [3, 16, 32, 512, 8192]
-        out_dtypes = [torch.float16, torch.bfloat16]
-        for out_dtype in out_dtypes:
-            self._test_accuracy(Ms, Ns, Ks, out_dtype)
+    return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]
+
+
+skip_condition = not torch.cuda.is_available() or (
+    torch.version.cuda is None
+    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
+)
+
+
+@pytest.mark.skipif(
+    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
+)
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
+@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
+@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
+def test_grouped_gemm_accuracy(out_dtype, M, N, K):
+    a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
+    b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
+    expected = torch.matmul(a, b.t()).to(out_dtype)
+
+    a_array = [a]
+    b_array = [b]
+    c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]
+
+    result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
+    cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)
+
+    torch.testing.assert_close(result_torch, expected)
+    torch.testing.assert_close(c_array[0], expected)


 if __name__ == "__main__":
-    if torch.cuda.is_available():
-        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
-        if cuda_version >= (12, 5):
-            unittest.main()
-        else:
-            print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")
+    pytest.main([__file__])
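Note: the CUDA 12.5 requirement moves from an imperative check in the `__main__` block to a module-level `skip_condition` consumed by `@pytest.mark.skipif`, so unsupported environments report skipped tests instead of silently running nothing. The comparison is plain tuple ordering; a small worked example with an illustrative version string (not taken from the diff):

    # Worked example of the tuple comparison behind skip_condition.
    cuda_version = tuple(map(int, "12.4".split(".")))  # "12.4" is illustrative -> (12, 4)

    # Tuples compare element-wise, so (12, 4) < (12, 5) is True and the tests
    # would be skipped with the reason string passed to @pytest.mark.skipif.
    print(cuda_version < (12, 5))  # True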
sgl-kernel/tests/test_fp8_blockwise_gemm.py

-import unittest
+import os
+import random
 from typing import Optional, Type

 import pytest
 import torch
 from sgl_kernel import fp8_blockwise_scaled_mm


 def cdiv(a: int, b: int) -> int:
     """Ceiling division."""
     return -(a // -b)
...
@@ -23,7 +24,6 @@ def baseline_scaled_mm(
     out_dtype: Type[torch.dtype],
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-
     # We treat N-dimensional group scaling as extended numpy-style broadcasting
     # in numpy simply stretches dimensions with an extent of 1 to match the
     # the target shape by repeating the data along that dimension (broadcasting)
...
@@ -51,62 +51,44 @@ def baseline_scaled_mm(
     scale_a = group_broadcast(scale_a, a.shape)
     scale_b = group_broadcast(scale_b, b.shape)

     output = torch.mm(
         (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
     ).to(out_dtype)
     if bias is not None:
         output = output + bias

     return output


-class TestFp8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, out_dtype, device):
-        fp8_info = torch.finfo(torch.float8_e4m3fn)
-        fp8_max, fp8_min = fp8_info.max, fp8_info.min
-        a_fp32 = (
-            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-        b_fp32 = (
-            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
-        scale_a_group_shape = (1, 128)
-        scale_b_group_shape = (128, 128)
-        scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
-        scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
-        scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
-        scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
-        scale_a = scale_a.t().contiguous().t()
-        scale_b = scale_b.t().contiguous().t()
-        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
-        o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
-        o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
-        rtol = 0.02
-        atol = 1
-        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-        print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096]
-        Ns = [128, 512, 1024, 4096]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        out_dtypes = [torch.bfloat16, torch.float16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for out_dtype in out_dtypes:
-                        self._test_accuracy_once(M, N, K, out_dtype, "cuda")
+def _test_accuracy_once(M, N, K, out_dtype, device):
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
+    scale_a_group_shape = (1, 128)
+    scale_b_group_shape = (128, 128)
+    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
+    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
+    scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
+    scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
+    scale_a = scale_a.t().contiguous().t()
+    scale_b = scale_b.t().contiguous().t()
+    o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
+    o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
+    rtol = 0.02
+    atol = 1
+    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
+    print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
+
+
+@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+def test_accuracy(M, N, K, out_dtype):
+    _test_accuracy_once(M, N, K, out_dtype, "cuda")


 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])
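Note: stacked `@pytest.mark.parametrize` decorators take the cartesian product of their value lists, so the grid the old nested `for` loops walked is now enumerated as individual test cases: 5 values of M, 4 of N, 5 of K and 2 output dtypes give 200 cases, each with its own node id that can be selected with `-k`. A small sketch of that count (the lists are copied from the diff):

    # Size of the parametrize grid generated for test_accuracy in this file.
    Ms = [1, 128, 512, 1024, 4096]
    Ns = [128, 512, 1024, 4096]
    Ks = [512, 1024, 4096, 8192, 16384]
    out_dtypes = ["bfloat16", "float16"]  # stand-ins for the torch dtypes

    print(len(Ms) * len(Ns) * len(Ks) * len(out_dtypes))  # 200 generated test cases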
sgl-kernel/tests/test_fp8_gemm.py

-import unittest
+import pytest
 import torch
 from sgl_kernel import fp8_scaled_mm


 def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     o = torch.matmul(a.to(torch.float32), b.to(torch.float32))

     o = o.to(torch.float32)
     temp1 = o * scale_a.view(-1, 1)
     temp2 = temp1 * scale_b.view(1, -1)
     final = temp2.to(out_dtype)
     if bias is not None:
         final = final + bias.view(1, -1)

     return final


-class TestFp8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
-        fp8_info = torch.finfo(torch.float8_e4m3fn)
-        fp8_max, fp8_min = fp8_info.max, fp8_info.min
-        a_fp32 = (
-            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-        b_fp32 = (
-            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-        scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
-        scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
-        if with_bias:
-            bias = torch.randn((N,), device=device, dtype=out_dtype)
-        else:
-            bias = None
-        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
-        b_fp8 = b_fp8.t()
-        o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
-        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
-        rtol = 0.02
-        atol = 1
-        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096]
-        Ns = [16, 128, 512, 1024, 4096]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        bias_opts = [True, False]
-        out_dtypes = [torch.bfloat16, torch.float16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for with_bias in bias_opts:
-                        for out_dtype in out_dtypes:
-                            self._test_accuracy_once(
-                                M, N, K, with_bias, out_dtype, "cuda"
-                            )
+def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
+    scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
+    if with_bias:
+        bias = torch.randn((N,), device=device, dtype=out_dtype)
+    else:
+        bias = None
+    b_fp8 = b_fp8.t()
+    o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
+    o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
+    rtol = 0.02
+    atol = 1
+    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
+    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
+
+
+@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("with_bias", [True, False])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+def test_accuracy(M, N, K, with_bias, out_dtype):
+    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")


 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])
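Note: both fp8 tests build inputs the same way: draw uniform values, stretch them to roughly the representable range of `torch.float8_e4m3fn`, clamp to `fp8_min`/`fp8_max`, then cast, so the cast itself never has to saturate. A small sketch of that pattern (my understanding is that e4m3fn is finite-only with a maximum around 448, hence the `finfo` lookup; the shapes are illustrative):

    import torch

    # Inspect the float8_e4m3fn range that the clamp in the test targets.
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    print(fp8_info.max, fp8_info.min)  # expected: 448.0 -448.0

    # Same recipe as the test: scale a uniform sample into [-fp8_max, fp8_max],
    # clamp, then cast down to fp8 (CPU is fine for this sketch).
    x = (torch.rand(4, 8) - 0.5) * 2 * fp8_info.max
    x_fp8 = x.clamp(min=fp8_info.min, max=fp8_info.max).to(torch.float8_e4m3fn)
    print(x_fp8.dtype)  # torch.float8_e4m3fn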
sgl-kernel/tests/test_int8_gemm.py

-import unittest
 import pytest
 import torch
 from sgl_kernel import int8_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
...
@@ -18,39 +17,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     return o.to(out_dtype)


-class TestInt8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
-        a = to_int8(torch.randn((M, K), device=device) * 5)
-        b = to_int8(torch.randn((N, K), device=device).t() * 5)
-        scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
-        scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
-        if with_bias:
-            bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
-        else:
-            bias = None
-        o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        torch.testing.assert_close(o, o1)
-        torch.testing.assert_close(o, o2)
-        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
-        Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        bias_opts = [True, False]
-        out_dtypes = [torch.float16, torch.bfloat16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for with_bias in bias_opts:
-                        for out_dtype in out_dtypes:
-                            self._test_accuracy_once(
-                                M, N, K, with_bias, out_dtype, "cuda"
-                            )
+def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
+    a = to_int8(torch.randn((M, K), device=device) * 5)
+    b = to_int8(torch.randn((N, K), device=device).t() * 5)
+    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
+    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
+    if with_bias:
+        bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
+    else:
+        bias = None
+    o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    torch.testing.assert_close(o, o1)
+    torch.testing.assert_close(o, o2)
+    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
+
+
+@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
+@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("with_bias", [True, False])
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+def test_accuracy(M, N, K, with_bias, out_dtype):
+    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")


 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])
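Note: the int8 test leans on a `to_int8` helper defined in a part of the file this diff does not show. Purely as a hypothetical sketch (the real helper may differ), such a conversion usually rounds, clamps to the int8 range, and casts, which is what the `torch.randn(...) * 5` inputs above would go through:

    import torch


    def to_int8_sketch(x: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for the file's to_int8 helper (not shown in the diff):
        # round to nearest, clamp into the int8 range, then cast.
        return x.round().clamp(-128, 127).to(torch.int8)


    a = to_int8_sketch(torch.randn((4, 8)) * 5)
    print(a.dtype, int(a.min()), int(a.max()))  # torch.int8, values within [-128, 127]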
sgl-kernel/tests/test_per_token_quant_fp8.py

...
@@ -51,5 +51,4 @@ def test_per_token_quant_compare_implementations(
 if __name__ == "__main__":
-    # Run the specific test function directly
     pytest.main([__file__])
sgl-kernel/tests/test_trt_allreduce.py

...
@@ -13,154 +13,185 @@ from torch.distributed import ProcessGroup
 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


+def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    ranks = list(range(world_size))
+    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+    dist.init_process_group(
+        backend="nccl",
+        init_method=distributed_init_method,
+        rank=rank,
+        world_size=world_size,
+    )
+    group = dist.group.WORLD
+
+    buffer_max_size = 8 * 1024 * 1024
+    barrier_max_size = 8 * (24 + 2) * 8
+
+    buffer_ptrs = None
+    tmp_result_buffer_ptrs = None
+    barrier_in_ptrs = None
+    barrier_out_ptrs = None
+    custom_ptr = None
+
+    try:
+        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(buffer_max_size, group=group)
+        tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(buffer_max_size, group=group)
+        barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(barrier_max_size, group=group)
+        barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(barrier_max_size, group=group)
+        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
+        custom_ptr = custom_ops.init_custom_reduce(
+            rank,
+            world_size,
+            rank_data,
+            buffer_ptrs,
+            tmp_result_buffer_ptrs,
+            barrier_in_ptrs,
+            barrier_out_ptrs,
+        )
+
+        test_loop = 10
+        for sz in test_sizes:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                for _ in range(test_loop):
+                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
+                    inp1_ref = inp1.clone()
+                    out1 = torch.empty_like(inp1)
+                    custom_ops.custom_reduce(custom_ptr, inp1, out1)
+                    dist.all_reduce(inp1_ref, group=group)
+                    torch.testing.assert_close(out1, inp1_ref)
+    finally:
+        dist.barrier(group=group)
+        if custom_ptr is not None:
+            custom_ops.custom_dispose(custom_ptr)
+        if buffer_ptrs:
+            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
+        if tmp_result_buffer_ptrs:
+            TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
+        if barrier_in_ptrs:
+            TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
+        if barrier_out_ptrs:
+            TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
+        dist.destroy_process_group(group=group)
+
+
 def get_open_port() -> int:
     # try ipv4
     try:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
             s.bind(("127.0.0.1", 0))
             return s.getsockname()[1]
     except OSError:
         # try ipv6
         with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
-            s.bind(("127.0.0.1", 0))
+            s.bind(("::1", 0))
             return s.getsockname()[1]


 def multi_process_parallel(
-    world_size: int, test_target: Any
+    world_size: int, test_target: Any, target_args: tuple = ()
 ) -> None:
     mp.set_start_method("spawn", force=True)

     procs = []
     distributed_init_port = get_open_port()
     for i in range(world_size):
-        proc = mp.Process(
-            target=test_target,
-            args=(world_size, i, distributed_init_port),
-        )
+        proc_args = (world_size, i, distributed_init_port) + target_args
+        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
         proc.start()
         procs.append(proc)

     for i in range(world_size):
         procs[i].join()
-        assert procs[i].exitcode == 0
+        assert (
+            procs[i].exitcode == 0
+        ), f"Process {i} failed with exit code {procs[i].exitcode}"


 class TestCustomAllReduce(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        random.seed(42)
-        cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
-        cls.world_sizes = [2, 4, 8]
+    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+    world_sizes = [2, 4, 8]

     @staticmethod
     def create_shared_buffer(
         size_in_bytes: int, group: Optional[ProcessGroup] = None
     ) -> List[int]:
         """
         Creates a shared buffer and returns a list of pointers
         representing the buffer on all processes in the group.
         """
         lib = CudaRTLibrary()
         pointer = lib.cudaMalloc(size_in_bytes)
         handle = lib.cudaIpcGetMemHandle(pointer)
         if group is None:
             group = dist.group.WORLD
         world_size = dist.get_world_size(group=group)
         rank = dist.get_rank(group=group)
-        handles = [None] * world_size
-        dist.all_gather_object(handles, handle, group=group)
+
+        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
+        input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
+        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
+        dist.all_gather(gathered_tensors, input_tensor, group=group)
+
+        handles = []
+        handle_type = type(handle)
+        for tensor in gathered_tensors:
+            bytes_list = tensor.cpu().tolist()
+            bytes_data = bytes(bytes_list)
+            handle_obj = handle_type()
+            ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
+            handles.append(handle_obj)

         pointers: List[int] = []
         for i, h in enumerate(handles):
             if i == rank:
-                pointers.append(pointer.value)  # type: ignore
+                pointers.append(pointer.value)
             else:
-                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
+                try:
+                    opened_ptr = lib.cudaIpcOpenMemHandle(h)
+                    pointers.append(opened_ptr.value)
+                except Exception as e:
+                    print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
+                    raise

         dist.barrier(group=group)
         return pointers

     @staticmethod
     def free_shared_buffer(
         pointers: List[int], group: Optional[ProcessGroup] = None
     ) -> None:
         if group is None:
             group = dist.group.WORLD
         rank = dist.get_rank(group=group)
         lib = CudaRTLibrary()
-        lib.cudaFree(ctypes.c_void_p(pointers[rank]))
+        if pointers and len(pointers) > rank and pointers[rank] is not None:
+            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
         dist.barrier(group=group)

     def test_correctness(self):
         for world_size in self.world_sizes:
-            if world_size > torch.cuda.device_count():
+            available_gpus = torch.cuda.device_count()
+            if world_size > available_gpus:
+                print(
+                    f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
+                )
                 continue
-            multi_process_parallel(world_size, self.correctness)
-            print(f"custom allreduce tp = {world_size}: OK")
-
-    def init_custom_allreduce(self, rank, world_size, group):
-        buffer_max_size = 8 * 1024 * 1024
-        barrier_max_size = 8 * (24 + 2) * 8
-
-        self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
-        self.tmp_result_buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
-        self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
-        self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
-        )
-
-        self.custom_ptr = custom_ops.init_custom_reduce(
-            rank,
-            world_size,
-            self.rank_data,
-            self.buffer_ptrs,
-            self.tmp_result_buffer_ptrs,
-            self.barrier_in_ptrs,
-            self.barrier_out_ptrs,
-        )
-
-    def custom_allreduce(self, inp, out):
-        custom_ops.custom_reduce(self.custom_ptr, inp, out)
-
-    def free_custom_allreduce(self, group):
-        self.free_shared_buffer(self.buffer_ptrs, group)
-        self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
-        self.free_shared_buffer(self.barrier_in_ptrs, group)
-        self.free_shared_buffer(self.barrier_out_ptrs, group)
-        custom_ops.custom_dispose(self.custom_ptr)
-
-    @staticmethod
-    def init_distributed_env(world_size, rank, distributed_init_port):
-        device = torch.device("cuda:0")
-        torch.cuda.set_device(device)
-        ranks = [i for i in range(world_size)]
-        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
-        dist.init_process_group(
-            backend="nccl",
-            init_method=distributed_init_method,
-            rank=rank,
-            world_size=world_size,
-        )
-        group = torch.distributed.new_group(ranks, backend="gloo")
-        return group
-
-    # compare result with torch.distributed
-    def correctness(self, world_size, rank, distributed_init_port):
-        group = self.init_distributed_env(world_size, rank, distributed_init_port)
-        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
-
-        test_loop = 10
-        for sz in self.test_sizes:
-            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-                for _ in range(test_loop):
-                    inp1 = torch.randint(
-                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
-                    )
-                    out1 = torch.empty_like(inp1)
-                    self.custom_allreduce(inp1, out1)
-                    dist.all_reduce(inp1, group=group)
-                    torch.testing.assert_close(out1, inp1)
-
-        self.free_custom_allreduce(group)
+            print(f"Running test for world_size={world_size}")
+            multi_process_parallel(
+                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
+            )
+            print(f"custom allreduce tp = {world_size}: OK")


 if __name__ == "__main__":
...
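Note: the main structural change in this file is that the per-rank body now lives in the module-level `_run_correctness_worker`, which spawn-based multiprocessing can import and pickle, and `multi_process_parallel` gained a `target_args` tuple appended after the standard `(world_size, rank, port)` arguments. A stripped-down sketch of that argument threading with a stub worker (the real worker needs GPUs, NCCL and the sgl-kernel ops; `_dummy_worker` and the port value are purely illustrative):

    # Minimal sketch of the target_args plumbing; not the real worker.
    import multiprocessing as mp


    def _dummy_worker(world_size, rank, distributed_init_port, test_sizes):
        # Same positional layout the real worker receives: (world_size, rank, port)
        # come from multi_process_parallel itself, test_sizes via target_args=(...,).
        assert rank < world_size and len(test_sizes) > 0


    def run(world_size: int, target_args: tuple = ()) -> None:
        mp.set_start_method("spawn", force=True)
        procs = []
        for i in range(world_size):
            proc_args = (world_size, i, 29500) + target_args  # 29500: placeholder port
            proc = mp.Process(target=_dummy_worker, args=proc_args, name=f"Worker-{i}")
            proc.start()
            procs.append(proc)
        for i, proc in enumerate(procs):
            proc.join()
            assert proc.exitcode == 0, f"Process {i} failed with exit code {proc.exitcode}"


    if __name__ == "__main__":
        run(2, target_args=([512, 4096],))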