Unverified Commit 0d4e3228 authored by William's avatar William Committed by GitHub
Browse files

[Feature] Add test for speculative_token_map (#4016)

parent 926f8efc
...@@ -31,6 +31,16 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm ...@@ -31,6 +31,16 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def load_token_map(token_map_path: str) -> List[int]:
    """Load a speculative hot-token map from a local file or the HF hub.

    If ``token_map_path`` exists locally it is loaded directly. Otherwise the
    path is interpreted as ``<repo_id>/<filename>``: the repo is snapshotted
    (skipping weight files) and the map is loaded from the downloaded copy.

    NOTE(review): ``torch.load`` unpickles the file — only load token maps
    from trusted sources.
    """
    if os.path.exists(token_map_path):
        return torch.load(token_map_path)

    # Treat the dirname as a HF repo id and fetch everything except weights.
    cache_dir = snapshot_download(
        os.path.dirname(token_map_path),
        ignore_patterns=["*.bin", "*.safetensors"],
    )
    local_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    return torch.load(local_path)
class EAGLEWorker(TpModelWorker): class EAGLEWorker(TpModelWorker):
def __init__( def __init__(
...@@ -48,20 +58,12 @@ class EAGLEWorker(TpModelWorker): ...@@ -48,20 +58,12 @@ class EAGLEWorker(TpModelWorker):
server_args.disable_cuda_graph = True server_args.disable_cuda_graph = True
if server_args.speculative_token_map is not None: if server_args.speculative_token_map is not None:
if os.path.exists(server_args.speculative_token_map): self.hot_token_id = load_token_map(server_args.speculative_token_map)
self.hot_token_id = torch.load(server_args.speculative_token_map)
else:
cache_dir = snapshot_download(
os.path.dirname(server_args.speculative_token_map),
ignore_patterns=["*.bin", "*.safetensors"],
)
file_path = os.path.join(
cache_dir, os.path.basename(server_args.speculative_token_map)
)
self.hot_token_id = torch.load(file_path)
server_args.json_model_override_args = ( server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
) )
else:
self.hot_token_id = None
super().__init__( super().__init__(
gpu_id=gpu_id, gpu_id=gpu_id,
...@@ -84,14 +86,12 @@ class EAGLEWorker(TpModelWorker): ...@@ -84,14 +86,12 @@ class EAGLEWorker(TpModelWorker):
# Share the embedding and lm_head # Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head() embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if server_args.speculative_token_map is not None: if self.hot_token_id is not None:
head = head.clone() head = head.clone()
self.hot_token_id = torch.tensor( self.hot_token_id = torch.tensor(
self.hot_token_id, dtype=torch.int32, device=head.device self.hot_token_id, dtype=torch.int32, device=head.device
) )
head.data = head.data[self.hot_token_id] head.data = head.data[self.hot_token_id]
else:
self.hot_token_id = None
self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
......
...@@ -95,6 +95,67 @@ class TestEAGLEEngine(unittest.TestCase): ...@@ -95,6 +95,67 @@ class TestEAGLEEngine(unittest.TestCase):
print("-" * 40) print("-" * 40)
class TestEAGLEEngineTokenMap(unittest.TestCase):
    """Verify EAGLE speculative decoding output, with and without a
    speculative token map, matches a plain (non-speculative) engine's
    greedy output."""

    # Shared engine settings; the token-map variant adds
    # ``speculative_token_map`` on top of this.
    BASE_CONFIG = {
        "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
        "speculative_draft_model_path": "lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B",
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 8,
        "speculative_num_draft_tokens": 64,
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 4,
        "dtype": "float16",
    }

    def setUp(self):
        # Greedy decoding so speculative output must match exactly.
        self.prompt = "Today is a sunny day and I like"
        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}

        # Produce the reference output with a plain engine, then free it
        # so the speculative engines can use the GPU memory.
        ref_engine = sgl.Engine(model_path=self.BASE_CONFIG["model_path"])
        self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
        ref_engine.shutdown()

    def test_token_map_accuracy(self):
        configs = [
            self.BASE_CONFIG,
            {
                **self.BASE_CONFIG,
                "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
            },
        ]

        for config in configs:
            print("testing config: ", config)
            # Label the subtest with the setting that actually varies so a
            # failure report identifies which config broke (the previous
            # constant label ``cuda_graph="enabled"`` was copy-paste residue).
            with self.subTest(token_map=config.get("speculative_token_map")):
                engine = sgl.Engine(**config)
                try:
                    self._test_basic_generation(engine)
                    self._test_batch_generation(engine)
                finally:
                    engine.shutdown()

    def _test_basic_generation(self, engine):
        # Single-prompt output must match the non-speculative reference.
        output = engine.generate(self.prompt, self.sampling_params)["text"]
        print(f"{output=}, {self.ref_output=}")
        self.assertEqual(output, self.ref_output)

    def _test_batch_generation(self, engine):
        # Smoke-test batched generation; outputs are printed, not asserted.
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = {"temperature": 0, "max_new_tokens": 30}
        outputs = engine.generate(prompts, params)

        for prompt, output in zip(prompts, outputs):
            print(f"Prompt: {prompt}")
            print(f"Generated: {output['text']}")
            print("-" * 40)
prompts = [ prompts = [
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]" "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]', '[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment