import multiprocessing as mp
import os
import random
import threading
import time
import unittest
from types import SimpleNamespace
from typing import List, Optional

import requests
import torch

import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

torch_dtype = torch.float16
prefill_tolerance = 5e-2
decode_tolerance: float = 5e-2


class TestEAGLEEngine(unittest.TestCase):
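    # EAGLE speculative decoding: a lightweight draft model proposes a token
    # tree (`speculative_num_steps` draft steps, keeping the top
    # `speculative_eagle_topk` candidates per step, with at most
    # `speculative_num_draft_tokens` tokens per verify pass), and the target
    # model verifies the proposals in a single forward pass.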
    BASE_CONFIG = {
        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
        "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 4,
        "speculative_num_draft_tokens": 8,
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 5,
    }
    NUM_CONFIGS = 3

    def setUp(self):
        self.prompt = "Today is a sunny day and I like"
        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}

        ref_engine = sgl.Engine(
            model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
        )
        self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
        ref_engine.shutdown()

    def test_correctness(self):
        configs = [
            # Basic config
            self.BASE_CONFIG,
            # Disable cuda graph
            {**self.BASE_CONFIG, "disable_cuda_graph": True},
            # Chunked prefill
            {**self.BASE_CONFIG, "chunked_prefill_size": 4},
        ]

        for i, config in enumerate(configs[: self.NUM_CONFIGS]):
            with self.subTest(i=i):
                print(f"{config=}")
                engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
                try:
                    self._test_single_generation(engine)
                    self._test_batch_generation(engine)
                    self._test_eos_token(engine)
                    self._test_acc_length(engine)
                finally:
                    engine.shutdown()
                print("=" * 100)

    def _test_single_generation(self, engine):
        output = engine.generate(self.prompt, self.sampling_params)["text"]
        print(f"{output=}, {self.ref_output=}")
        self.assertEqual(output, self.ref_output)

    def _test_batch_generation(self, engine):
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = {"temperature": 0, "max_new_tokens": 50}

        outputs = engine.generate(prompts, params)
        for prompt, output in zip(prompts, outputs):
            print(f"Prompt: {prompt}")
            print(f"Generated: {output['text']}")
            print("-" * 40)

        print(f"{engine.get_server_info()=}")

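        # avg_spec_accept_length is the server-wide mean number of tokens
        # committed per verification step; values well above 1.0 show that
        # speculation is actually accepting drafted tokens.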
        avg_spec_accept_length = engine.get_server_info()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 1.9)

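    # With skip_special_tokens=False, any emitted EOS token would appear in the
    # decoded text; re-encoding the output and checking for eos_token_id
    # verifies that generation stopped at EOS rather than sampling past it.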
    def _test_eos_token(self, engine):
        prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
        params = {
            "temperature": 0,
            "max_new_tokens": 1024,
            "skip_special_tokens": False,
        }

        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        output = engine.generate(prompt, params)["text"]
        print(f"{output=}")

        tokens = tokenizer.encode(output, truncation=False)
        self.assertNotIn(tokenizer.eos_token_id, tokens)

    def _test_acc_length(self, engine):
        prompt = [
            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
        ]
        sampling_params = {"temperature": 0, "max_new_tokens": 512}
        output = engine.generate(prompt, sampling_params)
        output = output[0]

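        # Acceptance length = completion_tokens / spec_verify_ct: the average
        # number of tokens committed per draft-verify round (1.0 would mean
        # speculation never accepted a drafted token).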
        if "spec_verify_ct" in output["meta_info"]:
            acc_length = (
                output["meta_info"]["completion_tokens"]
                / output["meta_info"]["spec_verify_ct"]
            )
        else:
            acc_length = 1.0

        speed = (
            output["meta_info"]["completion_tokens"]
            / output["meta_info"]["e2e_latency"]
        )
        print(f"{acc_length=}")
        self.assertGreater(acc_length, 3.6)


class TestEAGLEEngineTokenMap(TestEAGLEEngine):
    # Reuses the TestEAGLEEngine tests with a Llama-3 target/draft pair and a
    # frequency-pruned draft vocabulary; NUM_CONFIGS = 1 runs only the basic config.
    BASE_CONFIG = {
        "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
        "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 4,
        "speculative_num_draft_tokens": 8,
        "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 5,
    }
    NUM_CONFIGS = 1


class TestEAGLEServer(unittest.TestCase):
    PROMPTS = [
        "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]",
        '[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
        "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
        "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
        "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
    ]

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "8",
                "--speculative-num-draft-tokens",
                "64",
                "--mem-fraction-static",
                "0.7",
                "--chunked-prefill-size",
                "128",
                "--max-running-requests",
                "8",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

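    # Sequentially sends the full prompt set; serves as the well-behaved
    # workload in test_request_abort below.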
    def send_request(self):
        time.sleep(random.uniform(0, 2))
        for prompt in self.PROMPTS:
            url = self.base_url + "/generate"
            data = {
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            response = requests.post(url, json=data)
            assert response.status_code == 200

    def send_requests_abort(self):
        for prompt in self.PROMPTS:
            try:
                time.sleep(random.uniform(0, 2))
                url = self.base_url + "/generate"
                data = {
                    "model": "base",
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 1024,
                    },
                }
                # Set timeout=1s to mock a disconnected client.
                requests.post(url, json=data, timeout=1)
            except Exception as e:
                print(e)

    def test_request_abort(self):
        concurrency = 4
        threads = [
            threading.Thread(target=self.send_request) for _ in range(concurrency)
        ] + [
            threading.Thread(target=self.send_requests_abort)
            for _ in range(concurrency)
        ]
        for worker in threads:
            worker.start()
        for worker in threads:
            worker.join()

    def test_gsm8k(self):
        # Flush the server cache before the benchmark run.
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.20)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 2.9)

        # Wait a little bit so that the memory check happens.
        time.sleep(4)


class TestEAGLERetract(TestEAGLEServer):
    @classmethod
    def setUpClass(cls):
        # A small KV cache forces request retraction under load, which helps
        # find leaks.
        os.environ["SGLANG_CI_SMALL_KV_SIZE"] = "4500"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "8",
                "--speculative-num-draft-tokens",
                "64",
                "--mem-fraction-static",
                "0.7",
                "--chunked-prefill-size",
                "128",
                "--max-running-requests",
                "64",
            ],
        )


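# Same server tests as above, but forcing the Triton attention backend.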
class TestEAGLEServerTriton(TestEAGLEServer):
    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "8",
                "--speculative-num-draft-tokens",
                "64",
                "--mem-fraction-static",
                "0.7",
                "--attention-backend",
                "triton",
                "--max-running-requests",
                "8",
            ],
        )


if __name__ == "__main__":
    unittest.main()