Unverified Commit f127355a authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Add batch test for draft extend (#6672)

parent bdb962d7
......@@ -577,5 +577,67 @@ class TestEAGLEServerTriton(TestEAGLEServer):
)
class TestEAGLEDraftExtend(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
1,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
2,
"--max-running-requests",
4,
"--attention-backend",
"fa3",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_one_batch_accept_length(self):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
url = self.base_url + "/generate"
data = {
"text": prompts,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 512,
},
}
response = requests.post(url, json=data)
self.assertEqual(response.status_code, 200)
outputs = response.json()
for i in range(len(prompts)):
output = outputs[i]
if "spec_verify_ct" in output["meta_info"]:
acc_length = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["spec_verify_ct"]
)
else:
acc_length = 1.0
print(f"{acc_length=}")
self.assertGreater(acc_length, 1.50)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment