Unverified Commit a47bf391 authored by justdoit's avatar justdoit Committed by GitHub
Browse files

[Eagle2] Fix multiple concurrent request crashes (#2730)

parent b1706469
......@@ -245,9 +245,10 @@ class EAGLEDraftInput(SpecInfo):
) # (b, topk)
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
selected_input_index = (
topk_cs_index.flatten() // self.topk
) # shape: (b * topk)
selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
).repeat_interleave(self.topk)
batch.spec_info.hidden_states = batch.spec_info.hidden_states[
selected_input_index, :
]
......@@ -336,6 +337,7 @@ class EAGLEDraftInput(SpecInfo):
triton.next_power_of_2(self.spec_steps + 1),
)
batch.seq_lens_sum = sum(batch.seq_lens)
batch.input_ids = self.verified_id
self.verified_id = new_verified_id
......@@ -439,7 +441,14 @@ class EAGLEDraftInput(SpecInfo):
return kv_indices, cum_kv_seq_len, qo_indptr, None
def merge_batch(self, spec_info: EAGLEDraftInput):
if self.hidden_states is None:
self.hidden_states = spec_info.hidden_states
self.verified_id = spec_info.verified_id
self.sample_output = spec_info.sample_output
self.prev_mode = spec_info.prev_mode
return
if spec_info.hidden_states is None:
return
self.hidden_states = torch.cat(
[self.hidden_states, spec_info.hidden_states], axis=0
)
......
......@@ -169,6 +169,8 @@ class EAGLEWorker(TpModelWorker):
if not isinstance(reqs, List):
reqs = [reqs]
for req in reqs:
if req.rid not in self.finish_extend_len:
continue
req_len = (
len(req.origin_input_ids)
+ len(req.output_ids)
......
import multiprocessing
import random
import time
import unittest
import requests
from transformers import AutoConfig, AutoTokenizer
import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestEAGLEEngine(unittest.TestCase):
......@@ -64,5 +74,114 @@ class TestEAGLEEngine(unittest.TestCase):
assert tokenizer.eos_token_id not in tokens
prompts = [
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
def process(server_url: str):
time.sleep(random.uniform(0, 2))
for prompt in prompts:
url = server_url
data = {
"model": "base",
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 1024,
},
}
response = requests.post(url, json=data)
assert response.status_code == 200
def abort_process(server_url: str):
for prompt in prompts:
try:
time.sleep(1)
url = server_url
data = {
"model": "base",
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 1024,
},
}
# set timeout = 1s,mock disconnected
requests.post(url, json=data, timeout=1)
except:
pass
class TestEAGLELaunchServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
cls.model = "meta-llama/Llama-2-7b-chat-hf"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
speculative_draft_model_path,
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"4",
"--speculative-num-draft-tokens",
"16",
"--served-model-name",
"base",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_eagle_server_concurrency(self):
concurrency = 4
processes = [
multiprocessing.Process(
target=process,
kwargs={"server_url": self.base_url + "/generate"},
)
for _ in range(concurrency)
]
for worker in processes:
worker.start()
for p in processes:
p.join()
def test_eagle_server_request_abort(self):
concurrency = 4
processes = [
multiprocessing.Process(
target=process,
kwargs={"server_url": self.base_url + "/generate"},
)
for _ in range(concurrency)
] + [
multiprocessing.Process(
target=abort_process,
kwargs={"server_url": self.base_url + "/generate"},
)
for _ in range(concurrency)
]
for worker in processes:
worker.start()
for p in processes:
p.join()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment