Unverified Commit a47bf391 authored by justdoit's avatar justdoit Committed by GitHub
Browse files

[Eagle2] Fix multiple concurrent request crashes (#2730)

parent b1706469
...@@ -245,9 +245,10 @@ class EAGLEDraftInput(SpecInfo): ...@@ -245,9 +245,10 @@ class EAGLEDraftInput(SpecInfo):
) # (b, topk) ) # (b, topk)
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
selected_input_index = ( selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
topk_cs_index.flatten() // self.topk 0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
) # shape: (b * topk) ).repeat_interleave(self.topk)
batch.spec_info.hidden_states = batch.spec_info.hidden_states[ batch.spec_info.hidden_states = batch.spec_info.hidden_states[
selected_input_index, : selected_input_index, :
] ]
...@@ -336,6 +337,7 @@ class EAGLEDraftInput(SpecInfo): ...@@ -336,6 +337,7 @@ class EAGLEDraftInput(SpecInfo):
triton.next_power_of_2(self.spec_steps + 1), triton.next_power_of_2(self.spec_steps + 1),
) )
batch.seq_lens_sum = sum(batch.seq_lens)
batch.input_ids = self.verified_id batch.input_ids = self.verified_id
self.verified_id = new_verified_id self.verified_id = new_verified_id
...@@ -439,7 +441,14 @@ class EAGLEDraftInput(SpecInfo): ...@@ -439,7 +441,14 @@ class EAGLEDraftInput(SpecInfo):
return kv_indices, cum_kv_seq_len, qo_indptr, None return kv_indices, cum_kv_seq_len, qo_indptr, None
def merge_batch(self, spec_info: EAGLEDraftInput): def merge_batch(self, spec_info: EAGLEDraftInput):
if self.hidden_states is None:
self.hidden_states = spec_info.hidden_states
self.verified_id = spec_info.verified_id
self.sample_output = spec_info.sample_output
self.prev_mode = spec_info.prev_mode
return
if spec_info.hidden_states is None:
return
self.hidden_states = torch.cat( self.hidden_states = torch.cat(
[self.hidden_states, spec_info.hidden_states], axis=0 [self.hidden_states, spec_info.hidden_states], axis=0
) )
......
...@@ -169,6 +169,8 @@ class EAGLEWorker(TpModelWorker): ...@@ -169,6 +169,8 @@ class EAGLEWorker(TpModelWorker):
if not isinstance(reqs, List): if not isinstance(reqs, List):
reqs = [reqs] reqs = [reqs]
for req in reqs: for req in reqs:
if req.rid not in self.finish_extend_len:
continue
req_len = ( req_len = (
len(req.origin_input_ids) len(req.origin_input_ids)
+ len(req.output_ids) + len(req.output_ids)
......
import multiprocessing
import random
import time
import unittest import unittest
import requests
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
import sglang as sgl import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestEAGLEEngine(unittest.TestCase): class TestEAGLEEngine(unittest.TestCase):
...@@ -64,5 +74,114 @@ class TestEAGLEEngine(unittest.TestCase): ...@@ -64,5 +74,114 @@ class TestEAGLEEngine(unittest.TestCase):
assert tokenizer.eos_token_id not in tokens assert tokenizer.eos_token_id not in tokens
# Chat-formatted prompts that every client worker sends to the server, one
# request per entry. Each entry must be followed by a comma: adjacent string
# literals without one are silently concatenated into a single prompt.
prompts = [
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]",
    '[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
def process(server_url: str):
    """Send every entry of ``prompts`` to the server, one request at a time.

    Intended as a ``multiprocessing.Process`` target. The worker first sleeps
    for a random 0-2 s so that concurrently started workers do not all hit
    the server at the same instant.

    Args:
        server_url: Full URL of the endpoint that receives the POST requests.

    Raises:
        AssertionError: If any response has a status code other than 200.
    """
    # Random start offset to stagger the workers.
    time.sleep(random.uniform(0, 2))

    # The sampling configuration is identical for every request.
    sampling_params = {
        "temperature": 0,
        "max_new_tokens": 1024,
    }
    for text in prompts:
        payload = {
            "model": "base",
            "text": text,
            "sampling_params": sampling_params,
        }
        reply = requests.post(server_url, json=payload)
        assert reply.status_code == 200
def abort_process(server_url: str):
    """Send every entry of ``prompts`` but give up on each request after 1 s.

    Intended as a ``multiprocessing.Process`` target. The 1-second client
    timeout is used to mock a client that disconnects while the server is
    still working on the request.

    Args:
        server_url: Full URL of the endpoint that receives the POST requests.
    """
    for prompt in prompts:
        # Pause between requests; kept outside the ``try`` so that only the
        # HTTP call itself is covered by the exception handler.
        time.sleep(1)
        data = {
            "model": "base",
            "text": prompt,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 1024,
            },
        }
        try:
            # timeout=1s: mock a disconnected client.
            requests.post(server_url, json=data, timeout=1)
        except requests.exceptions.RequestException:
            # Timeouts and connection errors are the expected outcome here.
            # Catching only request errors (instead of a bare ``except``)
            # keeps KeyboardInterrupt/SystemExit and real bugs visible.
            pass
class TestEAGLELaunchServer(unittest.TestCase):
    """Server-level EAGLE tests driven by concurrent client processes."""

    @classmethod
    def setUpClass(cls):
        """Launch one server with EAGLE speculative-decoding arguments."""
        speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
        cls.model = "meta-llama/Llama-2-7b-chat-hf"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                speculative_draft_model_path,
                "--speculative-num-steps",
                "3",
                "--speculative-eagle-topk",
                "4",
                "--speculative-num-draft-tokens",
                "16",
                "--served-model-name",
                "base",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def _run_workers(self, targets):
        """Run one client process per target and require a clean exit from each.

        Args:
            targets: Callables accepting a ``server_url`` keyword argument;
                each one is run in its own ``multiprocessing.Process``.
        """
        generate_url = self.base_url + "/generate"
        workers = [
            multiprocessing.Process(
                target=target,
                kwargs={"server_url": generate_url},
            )
            for target in targets
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        # An exception raised inside a child (e.g. a failed status-code
        # assertion) does not propagate to the parent; it only shows up as a
        # non-zero exit code, so it has to be checked explicitly.
        for worker in workers:
            self.assertEqual(
                worker.exitcode,
                0,
                f"client worker {worker.name} exited with code {worker.exitcode}",
            )

    def test_eagle_server_concurrency(self):
        """Several clients sending all prompts concurrently must all succeed."""
        concurrency = 4
        self._run_workers([process] * concurrency)

    def test_eagle_server_request_abort(self):
        """Normal clients must still succeed while other clients time out."""
        concurrency = 4
        self._run_workers([process] * concurrency + [abort_process] * concurrency)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment