"tests/vscode:/vscode.git/clone" did not exist on "c3b40dc3e74dc0f552f76a01ae38b4f1385ad0af"
Commit 08c4bafa authored by lizhigong's avatar lizhigong
Browse files

add auto finish and restart thread when finish generate and start new generate

parent ca4ec0ce
...@@ -413,14 +413,13 @@ class LLMEngine: ...@@ -413,14 +413,13 @@ class LLMEngine:
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
if self.zero_overhead: if self.zero_overhead:
# self.step_switch = 0 # 0 step A 1 step B
# self.output_recorder = [None, None]
self.async_d2h = None self.async_d2h = None
self.last_record = None self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False) self.async_event = torch.cuda.Event(enable_timing=False)
self.zero_thread = threading.Thread(target=self.thread_zero_overhead) self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.q_recorder = queue.Queue() self.q_recorder = queue.Queue()
self.q_recorder.put(None) # None is use for first step ignore self.q_recorder.put(None) # None is use for first step ignore
self.thread_running = True
self.sem_m2s = threading.Semaphore(0) # main to scheduler thread self.sem_m2s = threading.Semaphore(0) # main to scheduler thread
self.zero_thread.start() self.zero_thread.start()
profile.StartTracer() profile.StartTracer()
...@@ -1317,9 +1316,16 @@ class LLMEngine: ...@@ -1317,9 +1316,16 @@ class LLMEngine:
def trans_last_output_tensor(self, last_output) -> torch.Tensor: def trans_last_output_tensor(self, last_output) -> torch.Tensor:
return None return None
def finish_thread(self):
if self.zero_overhead:
self.thread_running = False
self.sem_m2s.release()
def thread_zero_overhead(self): def thread_zero_overhead(self):
while True: while True:
self.sem_m2s.acquire() self.sem_m2s.acquire()
if not self.thread_running:
break
virtual_engine = 0 virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine] ctx = self.scheduler_contexts[virtual_engine]
...@@ -1377,6 +1383,10 @@ class LLMEngine: ...@@ -1377,6 +1383,10 @@ class LLMEngine:
self.last_record = [outputs, seq_group_metadata_list, scheduler_outputs] self.last_record = [outputs, seq_group_metadata_list, scheduler_outputs]
def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
if not self.thread_running:
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.thread_running = True
self.zero_thread.start()
self.sem_m2s.release() self.sem_m2s.release()
recode_output = self.q_recorder.get() recode_output = self.q_recorder.get()
if recode_output is None: # None is for the first step if recode_output is None: # None is for the first step
......
...@@ -1410,6 +1410,8 @@ class LLM: ...@@ -1410,6 +1410,8 @@ class LLM:
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
self.llm_engine.finish_thread()
# Sort the outputs by request ID. # Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment