Unverified commit 60abdb3e authored by Yineng Zhang, committed by GitHub

minor: cleanup test_eagle_infer (#3415)

parent 7b4e61ff
@@ -20,30 +20,7 @@ from sglang.test.test_utils import (
 class TestEAGLEEngine(unittest.TestCase):
-    def test_eagle_accuracy(self):
-        prompt1 = "Today is a sunny day and I like"
-        sampling_params1 = {"temperature": 0, "max_new_tokens": 8}
-
-        # Get the reference output
-        ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
-        ref_engine.shutdown()
-
-        # Test cases with different configurations
-        configs = [
-            # Original config
-            {
-                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-                "speculative_algorithm": "EAGLE",
-                "speculative_num_steps": 5,
-                "speculative_eagle_topk": 8,
-                "speculative_num_draft_tokens": 64,
-                "mem_fraction_static": 0.7,
-            },
-            # Config with CUDA graph disabled
-            {
-                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-                "speculative_algorithm": "EAGLE",
+    BASE_CONFIG = {
+        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+        "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+        "speculative_algorithm": "EAGLE",
@@ -51,48 +28,70 @@ class TestEAGLEEngine(unittest.TestCase):
-                "speculative_eagle_topk": 8,
-                "speculative_num_draft_tokens": 64,
-                "mem_fraction_static": 0.7,
-                "disable_cuda_graph": True,
-            },
-        ]
+        "speculative_eagle_topk": 8,
+        "speculative_num_draft_tokens": 64,
+        "mem_fraction_static": 0.7,
+    }
 
-        for config in configs:
-            # Launch EAGLE engine
-            engine = sgl.Engine(**config)
+    def setUp(self):
+        self.prompt = "Today is a sunny day and I like"
+        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+        self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
+        ref_engine.shutdown()
+
+    def test_eagle_accuracy(self):
+        configs = [
+            self.BASE_CONFIG,
+            {**self.BASE_CONFIG, "disable_cuda_graph": True},
+        ]
+
+        for config in configs:
+            with self.subTest(
+                cuda_graph=(
+                    "enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
+                )
+            ):
+                engine = sgl.Engine(**config)
+                try:
+                    self._test_basic_generation(engine)
+                    self._test_eos_token(engine)
+                    self._test_batch_generation(engine)
+                finally:
+                    engine.shutdown()
 
-            # Case 1: Test the output of EAGLE engine is the same as normal engine
-            out1 = engine.generate(prompt1, sampling_params1)["text"]
-            print(f"{out1=}, {ref_output=}")
-            self.assertEqual(out1, ref_output)
+    def _test_basic_generation(self, engine):
+        output = engine.generate(self.prompt, self.sampling_params)["text"]
+        print(f"{output=}, {self.ref_output=}")
+        self.assertEqual(output, self.ref_output)
 
-            # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-            prompt2 = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
-            sampling_params2 = {
-                "temperature": 0,
-                "max_new_tokens": 1024,
-                "skip_special_tokens": False,
-            }
+    def _test_eos_token(self, engine):
+        prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
+        params = {
+            "temperature": 0,
+            "max_new_tokens": 1024,
+            "skip_special_tokens": False,
+        }
 
-            tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-            out2 = engine.generate(prompt2, sampling_params2)["text"]
-            print(f"{out2=}")
-            tokens = tokenizer.encode(out2, truncation=False)
-            assert tokenizer.eos_token_id not in tokens
+        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+        output = engine.generate(prompt, params)["text"]
+        print(f"{output=}")
+
+        tokens = tokenizer.encode(output, truncation=False)
+        self.assertNotIn(tokenizer.eos_token_id, tokens)
 
-            # Case 3: Batched prompts
-            prompts = [
-                "Hello, my name is",
-                "The president of the United States is",
-                "The capital of France is",
-                "The future of AI is",
-            ]
-            sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
-            outputs = engine.generate(prompts, sampling_params3)
-            for prompt, output in zip(prompts, outputs):
-                print("===============================")
-                print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+    def _test_batch_generation(self, engine):
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        params = {"temperature": 0, "max_new_tokens": 30}
 
-            # Shutdown the engine
-            engine.shutdown()
+        outputs = engine.generate(prompts, params)
+        for prompt, output in zip(prompts, outputs):
+            print(f"Prompt: {prompt}")
+            print(f"Generated: {output['text']}")
+            print("-" * 40)
 
 prompts = [
...
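The cleanup above combines three standard `unittest` patterns: a class-level `BASE_CONFIG` dict that variants derive from via dict unpacking (`{**base, key: value}`), `self.subTest(...)` to label each configuration in the test report and keep iterating after a failure, and `try/finally` so the engine is shut down even when an assertion fails. The sketch below is a minimal, self-contained illustration of that structure; `FakeEngine` is a hypothetical stand-in for `sgl.Engine` so it runs without sglang or a GPU.

```python
import unittest


class FakeEngine:
    """Hypothetical stand-in for sgl.Engine; not part of sglang."""

    def __init__(self, **config):
        self.config = config

    def generate(self, prompt, params):
        # Echo the prompt; a real engine would return generated text.
        return {"text": f"{prompt} ..."}

    def shutdown(self):
        pass


class TestConfigVariants(unittest.TestCase):
    BASE_CONFIG = {"model_path": "dummy-model", "mem_fraction_static": 0.7}

    def test_variants(self):
        configs = [
            self.BASE_CONFIG,
            # {**base, key: value} copies the base dict and adds/overrides keys.
            {**self.BASE_CONFIG, "disable_cuda_graph": True},
        ]
        for config in configs:
            # subTest labels which variant failed and continues with the rest.
            with self.subTest(
                cuda_graph=(
                    "enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
                )
            ):
                engine = FakeEngine(**config)
                try:
                    out = engine.generate("Hello", {"max_new_tokens": 8})["text"]
                    self.assertTrue(out.startswith("Hello"))
                finally:
                    # Runs even if the assertion above fails.
                    engine.shutdown()


if __name__ == "__main__":
    unittest.main()
```

The `try/finally` matters here because a failed assertion would otherwise leave the engine, and whatever memory it holds, alive while the loop moves on to the next configuration.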