Commit 2c3f0e14 authored by lizhigong
Browse files

fix error when cpu is slower than dcu

parent d4b6b8cc
......@@ -214,7 +214,8 @@ def run_vllm(
use_tqdm=False)
use_beam_search = False
print("sleep 1")
time.sleep(1)
if not use_beam_search:
if args.profile:
profile_dir = args.profile_result_dir
......
......@@ -1330,16 +1330,25 @@ class LLMEngine:
self.sem_m2s.acquire()
if not self.thread_running:
break
last_outputs_ids = None
last_outputs_tensor = None
if self.last_record is not None:
last_output = self.last_record[0][0]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
......@@ -1348,15 +1357,6 @@ class LLMEngine:
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
last_outputs_ids = None
last_outputs_tensor = None
if self.last_record is not None:
last_output = self.last_record[0][0]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
......@@ -1388,6 +1388,7 @@ class LLMEngine:
def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
if not self.thread_running:
self.zero_thread.join()
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.thread_running = True
self.zero_thread.start()
......
......@@ -243,6 +243,9 @@ class LLM:
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
def __del__(self):
    """Ask the engine to finish its worker thread when this object is finalized.

    Python still invokes ``__del__`` on an instance whose ``__init__`` raised
    part-way through, so ``llm_engine`` may never have been assigned. The
    attribute is looked up defensively to avoid a secondary ``AttributeError``
    being reported from the finalizer in that case.
    """
    # NOTE(review): finish_thread() presumably stops the zero-overhead
    # background thread started by the engine -- confirm against LLMEngine.
    engine = getattr(self, "llm_engine", None)
    if engine is not None:
        engine.finish_thread()
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
......@@ -1408,8 +1411,7 @@ class LLM:
if use_tqdm:
pbar.close()
self.llm_engine.finish_thread()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
......
......@@ -492,6 +492,7 @@ def _greedy_sample(
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id
else:
next_token_ids = [samples_lst[sample_idx]]
......@@ -534,7 +535,8 @@ def _random_sample(
# Prompt phase.
parent_ids = [0] * sampling_params.n
if d2d_data.zero_overhead:
next_token_ids = [0] * sampling_params.n
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id
else:
next_token_ids = random_samples[
sample_idx, :sampling_params.n].tolist()
......@@ -542,7 +544,8 @@ def _random_sample(
# Generation phase.
parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
next_token_ids = [0] * num_parent_seqs
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id
else:
next_token_ids = random_samples[sample_idx:sample_idx +
num_parent_seqs, 0].tolist()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment