import multiprocessing as mp
import os
import random
import threading
import time
import unittest
from types import SimpleNamespace
from typing import List, Optional

import requests
import torch

import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

torch_dtype = torch.float16
prefill_tolerance = 5e-2
decode_tolerance: float = 5e-2


class TestEAGLEEngine(unittest.TestCase):
    """Offline ``sgl.Engine`` checks for EAGLE speculative decoding.

    ``test_correctness`` builds one engine per config (``BASE_CONFIG`` plus
    overrides) and runs every ``_test_*`` helper against it. Subclasses can
    swap ``BASE_CONFIG`` and cap how many configs run via ``NUM_CONFIGS``.
    """

    BASE_CONFIG = {
        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
        "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 4,
        "speculative_num_draft_tokens": 8,
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 5,
    }
    NUM_CONFIGS = 3

    def setUp(self):
        """Compute the reference completion with a non-speculative engine."""
        self.prompt = "Today is a sunny day and I like"
        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}

        ref_engine = sgl.Engine(
            model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
        )
        try:
            self.ref_output = ref_engine.generate(
                self.prompt, self.sampling_params
            )["text"]
        finally:
            # Shut down even if generate() raises so the reference engine is
            # not left running for the speculative engines built later.
            ref_engine.shutdown()

    def test_correctness(self):
        """Run all engine checks for the first ``NUM_CONFIGS`` configs."""
        configs = [
            # Basic config
            self.BASE_CONFIG,
            # Disable cuda graph
            {**self.BASE_CONFIG, "disable_cuda_graph": True},
            # Chunked prefill
            {**self.BASE_CONFIG, "chunked_prefill_size": 4},
        ]

        for i, config in enumerate(configs[: self.NUM_CONFIGS]):
            with self.subTest(i=i):
                print(f"{config=}")
                engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
                try:
                    self._test_single_generation(engine)
                    self._test_batch_generation(engine)
                    self._test_eos_token(engine)
                    self._test_acc_length(engine)
                finally:
                    engine.shutdown()
                print("=" * 100)

    def _test_single_generation(self, engine):
        """The speculative engine must reproduce the reference completion."""
        output = engine.generate(self.prompt, self.sampling_params)["text"]
        print(f"{output=}, {self.ref_output=}")
        self.assertEqual(output, self.ref_output)

    def _test_batch_generation(self, engine):
        """Generate a small batch and check the server-reported accept length."""
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = {"temperature": 0, "max_new_tokens": 50}

        outputs = engine.generate(prompts, params)
        for prompt, output in zip(prompts, outputs):
            print(f"Prompt: {prompt}")
            print(f"Generated: {output['text']}")
            print("-" * 40)

        # Query once; the info dict is used for both logging and the assertion.
        server_info = engine.get_server_info()
        print(f"{server_info=}")
        avg_spec_accept_length = server_info["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 1.9)

    def _test_eos_token(self, engine):
        """The decoded text must not contain the tokenizer's EOS token id."""
        # NOTE(review): "<>" looks like a stripped Llama-2 "<<SYS>>"/"<</SYS>>"
        # system tag -- confirm the intended prompt text.
        prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]"
        params = {
            "temperature": 0,
            "max_new_tokens": 1024,
            "skip_special_tokens": False,
        }

        # Use the tokenizer of the model this engine actually serves, so that
        # subclasses overriding BASE_CONFIG check the right EOS id.
        tokenizer = get_tokenizer(self.BASE_CONFIG["model_path"])
        output = engine.generate(prompt, params)["text"]
        print(f"{output=}")

        tokens = tokenizer.encode(output, truncation=False)
        self.assertNotIn(tokenizer.eos_token_id, tokens)

    def _test_acc_length(self, engine):
        """Check completion tokens per verify step on a long code prompt."""
        prompt = [
            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
        ]
        sampling_params = {"temperature": 0, "max_new_tokens": 512}
        output = engine.generate(prompt, sampling_params)[0]
        meta_info = output["meta_info"]

        if "spec_verify_ct" in meta_info:
            acc_length = meta_info["completion_tokens"] / meta_info["spec_verify_ct"]
        else:
            acc_length = 1.0

        speed = meta_info["completion_tokens"] / meta_info["e2e_latency"]
        print(f"{acc_length=}, {speed=}")
        self.assertGreater(acc_length, 3.6)


class TestEAGLEEngineTokenMap(TestEAGLEEngine):
    """Re-run the TestEAGLEEngine checks with a speculative token map.

    Inherits from TestEAGLEEngine so that ``BASE_CONFIG``/``NUM_CONFIGS`` are
    actually consumed by ``test_correctness``; as a bare ``TestCase`` this
    class would collect no tests. ``NUM_CONFIGS = 1`` runs only the basic config.
    """

    BASE_CONFIG = {
        "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
        "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 4,
        "speculative_num_draft_tokens": 8,
        "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 5,
    }
    NUM_CONFIGS = 1


class TestEAGLEServer(unittest.TestCase):
    """HTTP-server checks for EAGLE speculative decoding.

    Subclasses reuse every test method and only change how the server is
    launched in ``setUpClass``.
    """

    # NOTE(review): "<>" looks like a stripped Llama-2 "<<SYS>>"/"<</SYS>>"
    # system tag -- confirm the intended prompt text.
    PROMPTS = [
        "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]",
        '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
        "[INST] <>\\nYou are a helpful assistant.\\n<>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
        "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwho are you?[/INST]",
        "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwhere are you from?[/INST]",
    ]

    @classmethod
    def _launch_server(cls, extra_args):
        """Launch the EAGLE server with the shared speculative args plus ``extra_args``."""
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                5,
                "--speculative-eagle-topk",
                8,
                "--speculative-num-draft-tokens",
                64,
                "--mem-fraction-static",
                0.7,
                *extra_args,
            ],
        )

    @classmethod
    def setUpClass(cls):
        cls._launch_server(
            [
                "--chunked-prefill-size",
                128,
                "--max-running-requests",
                8,
            ]
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def send_request(self):
        """Send every prompt to ``/generate`` after a random delay; require HTTP 200."""
        time.sleep(random.uniform(0, 2))
        for prompt in self.PROMPTS:
            url = self.base_url + "/generate"
            data = {
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            response = requests.post(url, json=data)
            if response.status_code != 200:
                # Explicit raise (not assert) so the check survives python -O.
                raise AssertionError(
                    f"/generate returned {response.status_code}: {response.text}"
                )

    def send_requests_abort(self):
        """Send every prompt with a 1s client timeout to mimic disconnects."""
        for prompt in self.PROMPTS:
            try:
                time.sleep(random.uniform(0, 2))
                url = self.base_url + "/generate"
                data = {
                    "model": "base",
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 1024,
                    },
                }
                # set timeout = 1s, mock disconnected
                requests.post(url, json=data, timeout=1)
            except requests.exceptions.RequestException as e:
                # Expected: the client gives up (timeout) while the server is
                # still generating. Log and move on to the next prompt.
                print(e)

    def test_request_abort(self):
        """Mix normal requests with abandoned ones from concurrent threads.

        Exceptions raised in a ``threading.Thread`` target are only reported
        by ``threading.excepthook`` and do not fail the test, so worker
        failures are collected here and asserted in the main thread.
        """
        concurrency = 4
        errors: List[Exception] = []

        def run(target):
            try:
                target()
            except Exception as e:
                errors.append(e)

        targets = [self.send_request] * concurrency + [
            self.send_requests_abort
        ] * concurrency
        threads = [threading.Thread(target=run, args=(t,)) for t in targets]
        for worker in threads:
            worker.start()
        for worker in threads:
            worker.join()

        self.assertEqual(errors, [], f"worker threads raised: {errors!r}")

    def test_gsm8k(self):
        """Run few-shot GSM8K against the server; check accuracy and accept length."""
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.20)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 2.9)

        # Wait a little bit so that the memory check happens.
        time.sleep(4)


class TestEAGLERetract(TestEAGLEServer):
    """TestEAGLEServer tests with SGLANG_CI_SMALL_KV_SIZE and more running requests."""

    @classmethod
    def setUpClass(cls):
        # These configs help find a leak.
        env_key = "SGLANG_CI_SMALL_KV_SIZE"
        previous = os.environ.get(env_key)
        os.environ[env_key] = "4500"
        try:
            cls._launch_server(
                [
                    "--chunked-prefill-size",
                    128,
                    "--max-running-requests",
                    64,
                ]
            )
        finally:
            # The server subprocess has already inherited the variable at spawn
            # time. Restore it so other test classes launched later in this
            # process do not pick up the override.
            if previous is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous


class TestEAGLEServerTriton(TestEAGLEServer):
    """TestEAGLEServer tests with the triton attention backend."""

    @classmethod
    def setUpClass(cls):
        cls._launch_server(
            [
                "--attention-backend",
                "triton",
                "--max-running-requests",
                8,
            ]
        )


if __name__ == "__main__":
    unittest.main()