# test_radix_attention.py
import os
import random
import unittest

import requests

# Shared sglang test helpers used below: default model/URL/timeout constants,
# the base test-case class, and server process launch/teardown utilities.
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    kill_process_tree,
    popen_launch_server,
)


def gen_radix_tree(num_nodes=400, chunk_len=256):
    """Generate a shuffled list of request specs whose prompts share prefixes.

    The list starts from a fixed root prompt (117 copies of token 37 with a
    decode length of 217). Every further entry extends the prompt of an
    already-generated entry with a run of one repeated random token, so the
    prompts form a tree of common prefixes.

    Args:
        num_nodes: Number of entries to generate in addition to the root.
            Half (rounded down) are attached one at a time to a randomly
            chosen existing entry; the rest are attached in sibling groups
            of 1 to 10 under a single randomly chosen entry.
        chunk_len: Inclusive upper bound for both the length of each appended
            token run and each entry's ``decode_len``.

    Returns:
        A list of ``num_nodes + 1`` dicts (root included), in random order,
        each with keys ``"input_ids"`` (list of ints) and ``"decode_len"``
        (int).
    """

    def spawn_child(base):
        # The draw order (suffix length, decode length, token id) is kept
        # fixed so a seeded ``random`` module reproduces the same tree.
        suffix_len = random.randint(0, chunk_len)
        new_tokens = random.randint(0, chunk_len)
        fill_token = random.randint(0, 32000)
        return {
            "input_ids": base["input_ids"] + [fill_token] * suffix_len,
            "decode_len": new_tokens,
        }

    single_count = num_nodes // 2
    remaining = num_nodes - single_count

    tree = [{"input_ids": [37] * 117, "decode_len": 217}]

    # Phase 1: each new entry picks its own parent among all entries so far.
    for _ in range(single_count):
        tree.append(spawn_child(random.choice(tree)))

    # Phase 2: groups of siblings that all extend the same parent.
    while remaining > 0:
        fanout = random.randint(1, min(remaining, 10))
        hub = random.choice(tree)
        for _ in range(fanout):
            tree.append(spawn_child(hub))
        remaining -= fanout

    random.shuffle(tree)
    return tree


def run_test(base_url, nodes):
    """Send all nodes to the server in one batched ``/generate`` request.

    Args:
        base_url: Server base URL; ``"/generate"`` is appended to it.
        nodes: Iterable of dicts with ``"input_ids"`` and ``"decode_len"``
            keys, as produced by ``gen_radix_tree``.

    Raises:
        AssertionError: If the response status code is not 200. The message
            carries the status code and the start of the response body so a
            failing run can be diagnosed from the test log.
    """
    data = {
        "input_ids": [node["input_ids"] for node in nodes],
        "sampling_params": [
            # temperature 0 is sent for every request; each node's decode_len
            # becomes that request's max_new_tokens.
            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
        ],
    }

    res = requests.post(base_url + "/generate", json=data)
    assert res.status_code == 200, (
        f"/generate returned HTTP {res.status_code}: {res.text[:500]}"
    )


63
class TestRadixCacheFCFS(CustomTestCase):
    """Run the random prefix-tree workload against a server using ``fcfs``.

    One server process is launched for the whole class and killed afterwards.
    Subclasses override ``setUpClass`` to launch the server with different
    arguments and inherit the teardown and the test itself.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server with chunked prefill and the ``fcfs`` policy."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                # NOTE(review): presumably a deliberately small token budget
                # so the workload puts pressure on the cache; confirm intent.
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "fcfs",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def test_radix_attention(self):
        """Generate a default-size tree and require HTTP 200 from the server."""
        nodes = gen_radix_tree()
        run_test(self.base_url, nodes)


class TestRadixCacheLPM(TestRadixCacheFCFS):
    """Same workload as the parent class, with the ``lpm`` schedule policy.

    Only the server launch differs; the teardown and the test method are
    inherited from ``TestRadixCacheFCFS``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server with chunked prefill and the ``lpm`` policy."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "lpm",
            ],
        )


111
class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS):
    """Same workload with ``lpm`` and ``--disable-overlap-schedule`` set.

    Only the server launch differs; the teardown and the test method are
    inherited from ``TestRadixCacheFCFS``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server with ``lpm`` and the overlap schedule disabled."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--disable-overlap-schedule",
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "lpm",
            ],
        )


132
133
134
# Script entry point: runs every test case in this module via unittest.
if __name__ == "__main__":
    # Set before unittest.main() so it is in the environment when the tests
    # (and the server processes they launch) start. It is only set when this
    # file is executed directly, not when a test runner imports the module.
    # NOTE(review): presumably enables a retraction test mode in sglang;
    # confirm against where SGLANG_TEST_RETRACT is read.
    os.environ["SGLANG_TEST_RETRACT"] = "true"
    unittest.main()