Commit fdf9bf98 authored by lizhigong's avatar lizhigong
Browse files

Fix bug when decode length exceeds max model length

parent 09553647
......@@ -414,12 +414,12 @@ class LLMEngine:
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
if self.zero_overhead:
assert os.environ.get('HIP_ALLOC_INITIALIZE') == '0'
self.async_d2h = None
self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False)
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.q_recorder = queue.Queue()
self.q_recorder.put(None) # None is a placeholder so that the first step is ignored
self.thread_running = True
self.sem_m2s = threading.Semaphore(0) # main to scheduler thread
self.zero_thread.start()
......@@ -1317,9 +1317,6 @@ class LLMEngine:
else:
seq.append_token_id(sample.output_token, sample.logprobs)
def trans_last_output_tensor(self, last_output) -> "Optional[torch.Tensor]":
    """Return the tensor form of the previous step's output.

    This implementation is a stub: it ignores ``last_output`` and always
    returns ``None``. The return annotation is therefore optional rather
    than a bare ``torch.Tensor``.

    Args:
        last_output: Output of the previous step (unused here).

    Returns:
        Always ``None`` in this implementation.
    """
    # NOTE(review): presumably a hook meant to be filled in or overridden
    # elsewhere -- confirm against the callers.
    return None
def finish_thread(self):
if self.zero_overhead:
self.thread_running = False
......@@ -1348,6 +1345,11 @@ class LLMEngine:
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
if len(seq_group_metadata_list) == 0:
self.last_record = None
continue
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment