Unverified Commit ed01b451 authored by PGFLMG's avatar PGFLMG Committed by GitHub
Browse files

[Misc] Clean sgl-kernel test (#5216)

parent d050df36
...@@ -49,7 +49,6 @@ def test_verify_tree_greedy(): ...@@ -49,7 +49,6 @@ def test_verify_tree_greedy():
if torch.max(target_logits[i][j]) < 10: if torch.max(target_logits[i][j]) < 10:
target_logits[i][j][18] = 10 target_logits[i][j][18] = 10
print(f"{target_logits=}")
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32) target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
predict_shape = (12,) predict_shape = (12,)
...@@ -65,12 +64,6 @@ def test_verify_tree_greedy(): ...@@ -65,12 +64,6 @@ def test_verify_tree_greedy():
) # mutable ) # mutable
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable
print(f"{candidates=}")
print(f"{retrive_index=}")
print(f"{retrive_next_token=}")
print(f"{retrive_next_sibling=}")
print(f"{target_predict=}")
verify_tree_greedy( verify_tree_greedy(
predicts=predicts, predicts=predicts,
accept_index=accept_index, accept_index=accept_index,
...@@ -82,10 +75,6 @@ def test_verify_tree_greedy(): ...@@ -82,10 +75,6 @@ def test_verify_tree_greedy():
target_predict=target_predict, target_predict=target_predict,
) )
print(f"{predicts=}")
print(f"{accept_index=}")
print(f"{accept_token_num=}")
# Check the expected output. # Check the expected output.
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [ assert accept_index.tolist() == [
......
...@@ -3,18 +3,47 @@ import torch ...@@ -3,18 +3,47 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from sgl_kernel import tree_speculative_sampling_target_only from sgl_kernel import tree_speculative_sampling_target_only
test_cases = [
(
1,
1,
[3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
[[0, 3, 4, 5], [6, 10, 11, -1]],
[3, 2],
),
(
0, # threshold_single
0, # threshold_acc
[1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
[[0, 1, 2, -1], [6, 10, 11, -1]],
[2, 2],
),
]
@pytest.mark.parametrize(
"threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
test_cases,
)
def test_tree_speculative_sampling_target_only(
threshold_single,
threshold_acc,
expected_predicts,
expected_accept_index,
expected_accept_token_num,
):
"""
Tests the tree_speculative_sampling_target_only function using Pytest parameterization.
"""
device = "cuda"
def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1):
print(
f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n"
)
candidates = torch.tensor( candidates = torch.tensor(
[ [
[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5],
[7, 8, 9, 10, 11, 12], [7, 8, 9, 10, 11, 12],
], ],
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=device,
) )
retrive_index = torch.tensor( retrive_index = torch.tensor(
[ [
...@@ -22,7 +51,7 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc ...@@ -22,7 +51,7 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
[6, 7, 8, 9, 10, 11], [6, 7, 8, 9, 10, 11],
], ],
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=device,
) )
retrive_next_token = torch.tensor( retrive_next_token = torch.tensor(
[ [
...@@ -30,7 +59,7 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc ...@@ -30,7 +59,7 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
[4, 2, 3, -1, 5, -1], [4, 2, 3, -1, 5, -1],
], ],
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=device,
) )
retrive_next_sibling = torch.tensor( retrive_next_sibling = torch.tensor(
[ [
...@@ -38,45 +67,34 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc ...@@ -38,45 +67,34 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
[-1, -1, -1, -1, 1, -1], [-1, -1, -1, -1, 1, -1],
], ],
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=device,
) )
target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda") target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
target_logits[0, 0, 3] = 10 target_logits[0, 0, 3] = 10
target_logits[0, 3, 4] = 10 target_logits[0, 3, 4] = 10
target_logits[0, 4, 5] = 10 target_logits[0, 4, 5] = 10
target_logits[1, 0, 11] = 10 target_logits[1, 0, 11] = 10
target_logits[1, 4, 12] = 10 target_logits[1, 4, 12] = 10
for i in range(target_logits.shape[0]): for i in range(target_logits.shape[0]):
for j in range(target_logits.shape[1]): for j in range(target_logits.shape[1]):
if torch.max(target_logits[i][j]) < 10: if torch.max(target_logits[i, j]) < 10:
target_logits[i][j][18] = 10 target_logits[i, j, 18] = 10
temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device="cuda") temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device)
predict_shape = (12,) bs, num_draft_tokens = candidates.shape
num_spec_step = len(expected_accept_index[0])
predict_shape = (len(expected_predicts),)
bs = candidates.shape[0] predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device)
num_spec_step = 4 accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device=device)
num_draft_tokens = candidates.shape[1] accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device)
predicts = torch.full(
predict_shape, -1, dtype=torch.int32, device="cuda"
) # mutable
accept_index = torch.full(
(bs, num_spec_step), -1, dtype=torch.int32, device="cuda"
) # mutable
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable
expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1) expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1) target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device="cuda") draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
coins = torch.rand(bs, num_draft_tokens, device="cuda").to(torch.float32)
print(f"{candidates=}")
print(f"{retrive_index=}")
print(f"{retrive_next_token=}")
print(f"{retrive_next_sibling=}")
print(f"{coins=}")
tree_speculative_sampling_target_only( tree_speculative_sampling_target_only(
predicts=predicts, predicts=predicts,
...@@ -94,24 +112,15 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc ...@@ -94,24 +112,15 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
deterministic=True, deterministic=True,
) )
print(f"{predicts=}") assert (
print(f"{accept_index=}") predicts.tolist() == expected_predicts
print(f"{accept_token_num=}") ), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"
assert (
if threshold_single == 1 and threshold_acc == 1: accept_index.tolist() == expected_accept_index
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] ), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"
assert accept_index.tolist() == [ assert (
[0, 3, 4, 5], accept_token_num.tolist() == expected_accept_token_num
[6, 10, 11, -1], ), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"
]
assert accept_token_num.tolist() == [3, 2]
elif threshold_single == 0 and threshold_acc == 0:
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 1, 2, -1],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [2, 2]
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -79,7 +79,6 @@ def _test_accuracy_once(M, N, K, out_dtype, device): ...@@ -79,7 +79,6 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
rtol = 0.02 rtol = 0.02
atol = 1 atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
@pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096]) @pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
......
...@@ -28,7 +28,6 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device): ...@@ -28,7 +28,6 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(o, o1) torch.testing.assert_close(o, o1)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]) @pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
......
...@@ -70,8 +70,6 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim ...@@ -70,8 +70,6 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
ref_output, ref_output,
rtol=rtol, rtol=rtol,
atol=atol, atol=atol,
msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
) )
torch.testing.assert_close( torch.testing.assert_close(
...@@ -79,8 +77,6 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim ...@@ -79,8 +77,6 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
ref_new_kv, ref_new_kv,
rtol=rtol, rtol=rtol,
atol=atol, atol=atol,
msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
) )
......
...@@ -42,12 +42,10 @@ def test_topk_softmax(num_tokens, num_experts, topk): ...@@ -42,12 +42,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3 topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
), f"Weights mismatch: torch={topk_indices_ref} vs SGLang={topk_weights}" ), f"Weights mismatch: torch={topk_indices_ref} vs SGLang={topk_weights}"
assert torch.equal( assert torch.allclose(
topk_indices_ref, topk_indices topk_indices_ref.int(), topk_indices, atol=0, rtol=0
), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}" ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"
print("✅ Native torch and custom kernel implementations match.")
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -304,10 +304,10 @@ def test_per_token_group_quant_with_column_major( ...@@ -304,10 +304,10 @@ def test_per_token_group_quant_with_column_major(
scale_tma_aligned=scale_tma_aligned, scale_tma_aligned=scale_tma_aligned,
) )
assert torch.allclose( torch.testing.assert_close(
x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5 x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
) )
assert torch.allclose( torch.testing.assert_close(
x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5 x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
) )
......
...@@ -187,9 +187,6 @@ def test_correctness( ...@@ -187,9 +187,6 @@ def test_correctness(
pos_ids, query_flashinfer, key_flashinfer pos_ids, query_flashinfer, key_flashinfer
) )
print(query_ref_out)
print(query_flashinfer_out)
torch.testing.assert_close( torch.testing.assert_close(
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2 query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment