"src/graph/transform/cuda/cuda_to_block.hip" did not exist on "2647afc9b343bdafbf47b1c55ab2f8691bcee0a8"
Unverified commit 73dfd2df, authored by Liangsheng Yin, committed by GitHub
Browse files

[Test] Enhance radix cache test for spec cases (#12394)

parent df5192cf
import random
import requests
def gen_radix_tree(num_nodes=400, chunk_len=256):
    """Generate random requests whose prompts share prefixes like a radix tree.

    The tree starts from a fixed root request (117 copies of token 37 with a
    decode length of 217) and grows in two phases:

    1. ``num_nodes // 2`` children are added one at a time, each under a
       parent picked at random from every node created so far.
    2. The remaining children are added in sibling groups of 1 to 10 that all
       share one randomly picked parent.

    Args:
        num_nodes: Number of child nodes to create in addition to the root.
        chunk_len: Inclusive upper bound for both the number of tokens a child
            appends to its parent's prompt and the child's decode length.

    Returns:
        A shuffled list of ``num_nodes + 1`` dicts, each with the keys
        ``"input_ids"`` (list of int) and ``"decode_len"`` (int).
    """
    num0 = num_nodes // 2
    num1 = num_nodes - num0
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]

    # Phase 1: one child per randomly chosen parent.
    for _ in range(num0):
        parent = random.choice(nodes)
        nodes.append(_make_radix_child(parent, chunk_len))

    # Phase 2: groups of siblings under a single parent; the group size is
    # capped by the number of nodes still to be created.
    while num1 > 0:
        num_branch = random.randint(1, min(num1, 10))
        parent = random.choice(nodes)
        for _ in range(num_branch):
            nodes.append(_make_radix_child(parent, chunk_len))
        num1 -= num_branch

    random.shuffle(nodes)
    return nodes


def _make_radix_child(parent, chunk_len):
    """Build one child request that extends ``parent``'s prompt with a random run."""
    # The draw order (unique_len, decode_len, token_id) is kept fixed so that a
    # seeded ``random`` reproduces the same tree as the un-factored code did.
    unique_len = random.randint(0, chunk_len)
    decode_len = random.randint(0, chunk_len)
    token_id = random.randint(0, 32000)
    return {
        "input_ids": parent["input_ids"] + [token_id] * unique_len,
        "decode_len": decode_len,
    }
def run_radix_attention_test(base_url: str):
    """Send one batched ``/generate`` request built from a random radix tree.

    A tree is generated with ``gen_radix_tree()`` using its default sizes, and
    every node becomes one entry of the batch: its ``input_ids`` is the prompt
    and its ``decode_len`` is the ``max_new_tokens`` limit, with temperature 0.

    Args:
        base_url: Server base URL; ``"/generate"`` is appended to it.

    Raises:
        AssertionError: If the server does not answer with HTTP 200.
    """
    nodes = gen_radix_tree()
    data = {
        "input_ids": [node["input_ids"] for node in nodes],
        "sampling_params": [
            {"max_new_tokens": node["decode_len"], "temperature": 0}
            for node in nodes
        ],
    }
    res = requests.post(base_url + "/generate", json=data)
    # Include the status and the start of the body so a failing run is
    # diagnosable from the test log alone.
    assert (
        res.status_code == 200
    ), f"/generate returned HTTP {res.status_code}: {res.text[:500]}"
...@@ -2,7 +2,7 @@ import unittest ...@@ -2,7 +2,7 @@ import unittest
from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.kit_matched_stop import MatchedStopMixin from sglang.test.kits.matched_stop_kit import MatchedStopMixin
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
......
...@@ -78,7 +78,7 @@ suites = { ...@@ -78,7 +78,7 @@ suites = {
TestFile("test_deterministic.py", 320), TestFile("test_deterministic.py", 320),
TestFile("test_eagle_infer_a.py", 370), TestFile("test_eagle_infer_a.py", 370),
TestFile("test_eagle_infer_b.py", 700), TestFile("test_eagle_infer_b.py", 700),
TestFile("test_eagle_infer_beta.py", 300), TestFile("test_eagle_infer_beta.py", 90),
TestFile("test_ebnf_constrained.py", 108), TestFile("test_ebnf_constrained.py", 108),
TestFile("test_eval_fp8_accuracy.py", 303), TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_fa3.py", 376), TestFile("test_fa3.py", 376),
......
...@@ -4,7 +4,8 @@ from types import SimpleNamespace ...@@ -4,7 +4,8 @@ from types import SimpleNamespace
from sglang.srt.environ import envs from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.kit_matched_stop import MatchedStopMixin from sglang.test.kits.matched_stop_kit import MatchedStopMixin
from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
...@@ -65,6 +66,9 @@ class TestEagleServerBase(CustomTestCase, MatchedStopMixin): ...@@ -65,6 +66,9 @@ class TestEagleServerBase(CustomTestCase, MatchedStopMixin):
def tearDownClass(cls): def tearDownClass(cls):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
def test_radix_attention(self):
run_radix_attention_test(self.base_url)
def test_gsm8k(self): def test_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
......
import os
import random
import unittest import unittest
import requests from sglang.srt.environ import envs
from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
...@@ -15,52 +13,6 @@ from sglang.test.test_utils import ( ...@@ -15,52 +13,6 @@ from sglang.test.test_utils import (
) )
def gen_radix_tree(num_nodes=400, chunk_len=256):
    """Generate random requests whose prompts share prefixes like a radix tree.

    The tree starts from a fixed root request (117 copies of token 37 with a
    decode length of 217) and grows in two phases:

    1. ``num_nodes // 2`` children are added one at a time, each under a
       parent picked at random from every node created so far.
    2. The remaining children are added in sibling groups of 1 to 10 that all
       share one randomly picked parent.

    Args:
        num_nodes: Number of child nodes to create in addition to the root.
        chunk_len: Inclusive upper bound for both the number of tokens a child
            appends to its parent's prompt and the child's decode length.

    Returns:
        A shuffled list of ``num_nodes + 1`` dicts, each with the keys
        ``"input_ids"`` (list of int) and ``"decode_len"`` (int).
    """
    num0 = num_nodes // 2
    num1 = num_nodes - num0
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]

    # Phase 1: one child per randomly chosen parent.
    for _ in range(num0):
        parent = random.choice(nodes)
        nodes.append(_make_radix_child(parent, chunk_len))

    # Phase 2: groups of siblings under a single parent; the group size is
    # capped by the number of nodes still to be created.
    while num1 > 0:
        num_branch = random.randint(1, min(num1, 10))
        parent = random.choice(nodes)
        for _ in range(num_branch):
            nodes.append(_make_radix_child(parent, chunk_len))
        num1 -= num_branch

    random.shuffle(nodes)
    return nodes


def _make_radix_child(parent, chunk_len):
    """Build one child request that extends ``parent``'s prompt with a random run."""
    # The draw order (unique_len, decode_len, token_id) is kept fixed so that a
    # seeded ``random`` reproduces the same tree as the un-factored code did.
    unique_len = random.randint(0, chunk_len)
    decode_len = random.randint(0, chunk_len)
    token_id = random.randint(0, 32000)
    return {
        "input_ids": parent["input_ids"] + [token_id] * unique_len,
        "decode_len": decode_len,
    }
def run_test(base_url, nodes):
    """Send all ``nodes`` to the server as one batched ``/generate`` request.

    Args:
        base_url: Server base URL; ``"/generate"`` is appended to it.
        nodes: Iterable of dicts with ``"input_ids"`` (the prompt token ids)
            and ``"decode_len"`` (used as ``max_new_tokens``).

    Raises:
        AssertionError: If the server does not answer with HTTP 200.
    """
    data = {
        "input_ids": [node["input_ids"] for node in nodes],
        "sampling_params": [
            {"max_new_tokens": node["decode_len"], "temperature": 0}
            for node in nodes
        ],
    }
    res = requests.post(base_url + "/generate", json=data)
    # Include the status and the start of the body so a failing run is
    # diagnosable from the test log alone.
    assert (
        res.status_code == 200
    ), f"/generate returned HTTP {res.status_code}: {res.text[:500]}"
class TestRadixCacheFCFS(CustomTestCase): class TestRadixCacheFCFS(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -85,8 +37,7 @@ class TestRadixCacheFCFS(CustomTestCase): ...@@ -85,8 +37,7 @@ class TestRadixCacheFCFS(CustomTestCase):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
def test_radix_attention(self): def test_radix_attention(self):
nodes = gen_radix_tree() run_radix_attention_test(self.base_url)
run_test(self.base_url, nodes)
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
...@@ -132,5 +83,5 @@ class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS): ...@@ -132,5 +83,5 @@ class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS):
if __name__ == "__main__": if __name__ == "__main__":
os.environ["SGLANG_TEST_RETRACT"] = "true" envs.SGLANG_TEST_RETRACT.set(True)
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment