sglang · Commit ed01b451 (Unverified)

[Misc] Clean sgl-kernel test (#5216)

Authored Apr 11, 2025 by PGFLMG; committed by GitHub Apr 10, 2025
Parent: d050df36

Showing 8 changed files with 63 additions and 76 deletions (+63, -76)
Changed files:
- sgl-kernel/tests/speculative/test_eagle_utils.py (+0, -11)
- sgl-kernel/tests/speculative/test_speculative_sampling.py (+59, -50)
- sgl-kernel/tests/test_fp8_blockwise_gemm.py (+0, -1)
- sgl-kernel/tests/test_int8_gemm.py (+0, -1)
- sgl-kernel/tests/test_lightning_attention_decode.py (+0, -4)
- sgl-kernel/tests/test_moe_topk_softmax.py (+2, -4)
- sgl-kernel/tests/test_per_token_group_quant_8bit.py (+2, -2)
- sgl-kernel/tests/test_rotary_embedding.py (+0, -3)
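For a quick local check of the touched files, one option (a sketch, assuming a CUDA GPU and an installed sgl-kernel build, which this page does not confirm) is to drive pytest programmatically, mirroring the `pytest.main([__file__])` entry points the tests themselves keep:

```python
# Hypothetical local run of two of the changed test files; the paths are taken
# from the file list above and are relative to the repository root.
import pytest

pytest.main(
    [
        "sgl-kernel/tests/speculative/test_speculative_sampling.py",
        "sgl-kernel/tests/test_moe_topk_softmax.py",
        "-v",  # verbose: list each parametrized case individually
    ]
)
```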
sgl-kernel/tests/speculative/test_eagle_utils.py

@@ -49,7 +49,6 @@ def test_verify_tree_greedy():
             if torch.max(target_logits[i][j]) < 10:
                 target_logits[i][j][18] = 10
-    print(f"{target_logits=}")
     target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
     predict_shape = (12,)
@@ -65,12 +64,6 @@ def test_verify_tree_greedy():
     )  # mutable
     accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable

-    print(f"{candidates=}")
-    print(f"{retrive_index=}")
-    print(f"{retrive_next_token=}")
-    print(f"{retrive_next_sibling=}")
-    print(f"{target_predict=}")
-
     verify_tree_greedy(
         predicts=predicts,
         accept_index=accept_index,
@@ -82,10 +75,6 @@ def test_verify_tree_greedy():
         target_predict=target_predict,
     )

-    print(f"{predicts=}")
-    print(f"{accept_index=}")
-    print(f"{accept_token_num=}")
-
     # Check the expected output.
     assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
     assert accept_index.tolist() == [
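The removed lines above are temporary debug prints built on Python's f-string `=` specifier (available since 3.8), which echoes both the expression and its value. A minimal illustration, unrelated to any sglang API:

```python
# The '=' specifier prints the expression text followed by its repr.
target = [3, -1, 4]
print(f"{target=}")  # prints: target=[3, -1, 4]
```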
sgl-kernel/tests/speculative/test_speculative_sampling.py

@@ -3,18 +3,47 @@ import torch
 import torch.nn.functional as F
 from sgl_kernel import tree_speculative_sampling_target_only

+test_cases = [
+    (
+        1,
+        1,
+        [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
+        [[0, 3, 4, 5], [6, 10, 11, -1]],
+        [3, 2],
+    ),
+    (
+        0,  # threshold_single
+        0,  # threshold_acc
+        [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
+        [[0, 1, 2, -1], [6, 10, 11, -1]],
+        [2, 2],
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
+    test_cases,
+)
+def test_tree_speculative_sampling_target_only(
+    threshold_single,
+    threshold_acc,
+    expected_predicts,
+    expected_accept_index,
+    expected_accept_token_num,
+):
+    """
+    Tests the tree_speculative_sampling_target_only function using Pytest parameterization.
+    """
+    device = "cuda"

-def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1):
-    print(
-        f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n"
-    )
     candidates = torch.tensor(
         [
             [0, 1, 2, 3, 4, 5],
             [7, 8, 9, 10, 11, 12],
         ],
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )
     retrive_index = torch.tensor(
         [
@@ -22,7 +51,7 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
             [6, 7, 8, 9, 10, 11],
         ],
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )
     retrive_next_token = torch.tensor(
         [
@@ -30,7 +59,7 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
             [4, 2, 3, -1, 5, -1],
         ],
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )
     retrive_next_sibling = torch.tensor(
         [
@@ -38,45 +67,34 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
             [-1, -1, -1, -1, 1, -1],
         ],
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )

-    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
+    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
     target_logits[0, 0, 3] = 10
     target_logits[0, 3, 4] = 10
     target_logits[0, 4, 5] = 10
     target_logits[1, 0, 11] = 10
     target_logits[1, 4, 12] = 10
     for i in range(target_logits.shape[0]):
         for j in range(target_logits.shape[1]):
-            if torch.max(target_logits[i][j]) < 10:
-                target_logits[i][j][18] = 10
+            if torch.max(target_logits[i, j]) < 10:
+                target_logits[i, j, 18] = 10

-    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device="cuda")
-    predict_shape = (12,)
+    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device)

-    bs = candidates.shape[0]
-    num_spec_step = 4
-    num_draft_tokens = candidates.shape[1]
+    bs, num_draft_tokens = candidates.shape
+    num_spec_step = len(expected_accept_index[0])
+    predict_shape = (len(expected_predicts),)

-    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")  # mutable
-    accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device="cuda")  # mutable
-    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable
+    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device)
+    accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device=device)
+    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device)

     expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
     target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
-    draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device="cuda")
-    coins = torch.rand(bs, num_draft_tokens, device="cuda").to(torch.float32)
-
-    print(f"{candidates=}")
-    print(f"{retrive_index=}")
-    print(f"{retrive_next_token=}")
-    print(f"{retrive_next_sibling=}")
-    print(f"{coins=}")
+    draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
+    coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)

     tree_speculative_sampling_target_only(
         predicts=predicts,
@@ -94,24 +112,15 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
         deterministic=True,
     )

-    print(f"{predicts=}")
-    print(f"{accept_index=}")
-    print(f"{accept_token_num=}")
-
-    if threshold_single == 1 and threshold_acc == 1:
-        assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
-        assert accept_index.tolist() == [
-            [0, 3, 4, 5],
-            [6, 10, 11, -1],
-        ]
-        assert accept_token_num.tolist() == [3, 2]
-    elif threshold_single == 0 and threshold_acc == 0:
-        assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
-        assert accept_index.tolist() == [
-            [0, 1, 2, -1],
-            [6, 10, 11, -1],
-        ]
-        assert accept_token_num.tolist() == [2, 2]
+    assert (
+        predicts.tolist() == expected_predicts
+    ), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"
+    assert (
+        accept_index.tolist() == expected_accept_index
+    ), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"
+    assert (
+        accept_token_num.tolist() == expected_accept_token_num
+    ), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"

 if __name__ == "__main__":
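The refactor above replaces hard-coded threshold branches with `pytest.mark.parametrize` over a module-level `test_cases` list. A minimal, self-contained sketch of that pattern (the `add` function and its cases are hypothetical stand-ins, not part of sgl-kernel):

```python
import pytest


def add(a, b):
    # Hypothetical stand-in for the kernel call under test.
    return a + b


test_cases = [
    (1, 1, 2),
    (0, 0, 0),
]


@pytest.mark.parametrize("a, b, expected", test_cases)
def test_add(a, b, expected):
    # Each tuple becomes an independent test case with its own pass/fail report.
    assert add(a, b) == expected, f"mismatch for inputs ({a}, {b})"
```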
sgl-kernel/tests/test_fp8_blockwise_gemm.py

@@ -79,7 +79,6 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
     rtol = 0.02
     atol = 1
     torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-    print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")


 @pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
sgl-kernel/tests/test_int8_gemm.py

@@ -28,7 +28,6 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
     o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     torch.testing.assert_close(o, o1)
-    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")


 @pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
sgl-kernel/tests/test_lightning_attention_decode.py

@@ -70,8 +70,6 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
         ref_output,
         rtol=rtol,
         atol=atol,
-        msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
-        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
     )

     torch.testing.assert_close(
@@ -79,8 +77,6 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
         ref_new_kv,
         rtol=rtol,
         atol=atol,
-        msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
-        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
     )
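Dropping the hand-built `msg=` strings above leans on the fact that `torch.testing.assert_close` already raises an `AssertionError` describing the mismatch (number of differing elements, greatest absolute and relative difference). A small sketch of that default report, with made-up tensors:

```python
import torch

a = torch.zeros(3)
b = torch.tensor([0.0, 0.0, 1.0])

try:
    torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)
except AssertionError as err:
    # The built-in message already says how many elements differ and by how much.
    print(err)
```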
sgl-kernel/tests/test_moe_topk_softmax.py

@@ -42,12 +42,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
         topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
     ), f"Weights mismatch: torch={topk_indices_ref} vs SGLang={topk_weights}"

-    assert torch.equal(
-        topk_indices_ref, topk_indices
+    assert torch.allclose(
+        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
     ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"
-    print("✅ Native torch and custom kernel implementations match.")
-

 if __name__ == "__main__":
     pytest.main([__file__])
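The change above swaps `torch.equal` for `torch.allclose(..., atol=0, rtol=0)` after casting the reference indices with `.int()`; with zero tolerances, `allclose` degenerates to an exact element-wise match on dtype-aligned tensors. A tiny sketch with made-up index tensors:

```python
import torch

ref = torch.tensor([2, 0, 1], dtype=torch.int64)
out = torch.tensor([2, 0, 1], dtype=torch.int32)

# Casting aligns dtypes; atol=0, rtol=0 turns allclose into an exact comparison.
assert torch.allclose(ref.int(), out, atol=0, rtol=0)
```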
sgl-kernel/tests/test_per_token_group_quant_8bit.py

@@ -304,10 +304,10 @@ def test_per_token_group_quant_with_column_major(
         scale_tma_aligned=scale_tma_aligned,
     )

-    assert torch.allclose(
+    torch.testing.assert_close(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
     )
-    assert torch.allclose(
+    torch.testing.assert_close(
         x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
     )
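Switching from `assert torch.allclose(...)` to `torch.testing.assert_close(...)` above trades a bare boolean check for an assertion that reports where and by how much the tensors differ. A minimal sketch of the difference, using made-up values:

```python
import torch

x = torch.tensor([1.0, 2.0])
y = torch.tensor([1.0, 2.5])

print(torch.allclose(x, y, rtol=1e-3, atol=1e-5))  # False, with no further detail

try:
    torch.testing.assert_close(x, y, rtol=1e-3, atol=1e-5)
except AssertionError as err:
    print(err)  # detailed mismatch report (element count, greatest differences)
```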
sgl-kernel/tests/test_rotary_embedding.py

@@ -187,9 +187,6 @@ def test_correctness(
         pos_ids, query_flashinfer, key_flashinfer
     )

-    print(query_ref_out)
-    print(query_flashinfer_out)
-
     torch.testing.assert_close(
         query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
     )