Commit fdf9bf98 authored by lizhigong
Browse files

Fix bug when decode length exceeds max model length

parent 09553647
...@@ -414,12 +414,12 @@ class LLMEngine: ...@@ -414,12 +414,12 @@ class LLMEngine:
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
if self.zero_overhead: if self.zero_overhead:
assert os.environ.get('HIP_ALLOC_INITIALIZE') == '0'
self.async_d2h = None self.async_d2h = None
self.last_record = None self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False) self.async_event = torch.cuda.Event(enable_timing=False)
self.zero_thread = threading.Thread(target=self.thread_zero_overhead) self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.q_recorder = queue.Queue() self.q_recorder = queue.Queue()
self.q_recorder.put(None) # None is use for first step ignore
self.thread_running = True self.thread_running = True
self.sem_m2s = threading.Semaphore(0) # main to scheduler thread self.sem_m2s = threading.Semaphore(0) # main to scheduler thread
self.zero_thread.start() self.zero_thread.start()
...@@ -1317,9 +1317,6 @@ class LLMEngine: ...@@ -1317,9 +1317,6 @@ class LLMEngine:
else: else:
seq.append_token_id(sample.output_token, sample.logprobs) seq.append_token_id(sample.output_token, sample.logprobs)
def trans_last_output_tensor(self, last_output) -> torch.Tensor:
return None
def finish_thread(self): def finish_thread(self):
if self.zero_overhead: if self.zero_overhead:
self.thread_running = False self.thread_running = False
...@@ -1348,6 +1345,11 @@ class LLMEngine: ...@@ -1348,6 +1345,11 @@ class LLMEngine:
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True) self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
self.async_event.record() self.async_event.record()
self.q_recorder.put(self.last_record) self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
if len(seq_group_metadata_list) == 0:
self.last_record = None
continue
finished_requests_ids = self.scheduler[ finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids() virtual_engine].get_and_reset_finished_requests_ids()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment