"examples/pytorch/vscode:/vscode.git/clone" did not exist on "fff3dd9593554051e77a051347ea25e05c078985"
Unverified Commit 6222e1c2 authored by Yineng Zhang, committed by GitHub

add disable cuda graph unit test for eagle 2 (#3412)

parent fad315cb
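
This commit parameterizes TestEAGLEEngine: the engine arguments move out of a direct sgl.Engine(...) call and into a configs list, and the three existing assertions (greedy output matches the reference engine, no unexpected EOS token in a long generation, batched prompts) now run once per config. The second config is identical to the first except that it sets "disable_cuda_graph": True, so the EAGLE-2 speculative path is also exercised without CUDA graph capture.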
@@ -30,51 +30,69 @@ class TestEAGLEEngine(unittest.TestCase):
         ref_output = ref_engine.generate(prompt, sampling_params)["text"]
         ref_engine.shutdown()
 
-        # Launch EAGLE engine
-        engine = sgl.Engine(
-            model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-            speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-            speculative_algorithm="EAGLE",
-            speculative_num_steps=5,
-            speculative_eagle_topk=8,
-            speculative_num_draft_tokens=64,
-            mem_fraction_static=0.7,
-        )
-
-        # Case 1: Test the output of EAGLE engine is the same as normal engine
-        out1 = engine.generate(prompt, sampling_params)["text"]
-        print(f"{out1=}, {ref_output=}")
-        self.assertEqual(out1, ref_output)
-
-        # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-        prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
-        sampling_params = {
-            "temperature": 0,
-            "max_new_tokens": 1024,
-            "skip_special_tokens": False,
-        }
-        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        out2 = engine.generate(prompt, sampling_params)["text"]
-        print(f"{out2=}")
-        tokens = tokenizer.encode(out2, truncation=False)
-        assert tokenizer.eos_token_id not in tokens
-
-        # Case 3: Batched prompts
-        prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        sampling_params = {"temperature": 0, "max_new_tokens": 30}
-        outputs = engine.generate(prompts, sampling_params)
-        for prompt, output in zip(prompts, outputs):
-            print("===============================")
-            print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
-
-        # Shutdown the engine
-        engine.shutdown()
+        # Test cases with different configurations
+        configs = [
+            # Original config
+            {
+                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "speculative_algorithm": "EAGLE",
+                "speculative_num_steps": 5,
+                "speculative_eagle_topk": 8,
+                "speculative_num_draft_tokens": 64,
+                "mem_fraction_static": 0.7,
+            },
+            # Config with CUDA graph disabled
+            {
+                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "speculative_algorithm": "EAGLE",
+                "speculative_num_steps": 5,
+                "speculative_eagle_topk": 8,
+                "speculative_num_draft_tokens": 64,
+                "mem_fraction_static": 0.7,
+                "disable_cuda_graph": True,
+            },
+        ]
+
+        for config in configs:
+            # Launch EAGLE engine
+            engine = sgl.Engine(**config)
+
+            # Case 1: Test the output of EAGLE engine is the same as normal engine
+            out1 = engine.generate(prompt, sampling_params)["text"]
+            print(f"{out1=}, {ref_output=}")
+            self.assertEqual(out1, ref_output)
+
+            # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
+            prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+            sampling_params = {
+                "temperature": 0,
+                "max_new_tokens": 1024,
+                "skip_special_tokens": False,
+            }
+            tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+            out2 = engine.generate(prompt, sampling_params)["text"]
+            print(f"{out2=}")
+            tokens = tokenizer.encode(out2, truncation=False)
+            assert tokenizer.eos_token_id not in tokens
+
+            # Case 3: Batched prompts
+            prompts = [
+                "Hello, my name is",
+                "The president of the United States is",
+                "The capital of France is",
+                "The future of AI is",
+            ]
+            sampling_params = {"temperature": 0, "max_new_tokens": 30}
+            outputs = engine.generate(prompts, sampling_params)
+            for prompt, output in zip(prompts, outputs):
+                print("===============================")
+                print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+
+            # Shutdown the engine
+            engine.shutdown()
 
         prompts = [
...
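
One detail of the new loop is worth flagging: Case 2 and Case 3 rebind prompt and sampling_params, so from the second config onward Case 1 appears to generate from the last batched prompt rather than from the inputs that produced ref_output. Below is a minimal sketch of one way to keep each config's run independent; it slots into the test method above, and the base_prompt/base_sampling_params names and the try/finally are illustrative additions, not part of the patch:

    # Sketch only (assumed fix, not in the patch): snapshot the reference
    # inputs once so every config starts Case 1 from the same prompt and
    # sampling params that produced ref_output.
    base_prompt = prompt
    base_sampling_params = dict(sampling_params)

    for config in configs:
        prompt = base_prompt
        sampling_params = dict(base_sampling_params)
        engine = sgl.Engine(**config)
        try:
            out1 = engine.generate(prompt, sampling_params)["text"]
            self.assertEqual(out1, ref_output)
            # ... Cases 2 and 3 exactly as in the patch ...
        finally:
            # Shut the engine down even if an assertion fails, so a failing
            # config does not leave GPU memory allocated for the next one.
            engine.shutdown()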