Commit d4b6b8cc authored by lizhigong's avatar lizhigong
Browse files

support serving in use zero overhead

parent 024e595d
...@@ -280,8 +280,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -280,8 +280,6 @@ class _AsyncLLMEngine(LLMEngine):
""" """
# these are cached outputs from previous iterations. None if on first # these are cached outputs from previous iterations. None if on first
# iteration # iteration
if self.zero_overhead:
return self.zero_overhead_step()
cached_outputs = self.cached_scheduler_outputs[virtual_engine] cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs scheduler_outputs = cached_outputs.scheduler_outputs
...@@ -728,7 +726,6 @@ class AsyncLLMEngine(EngineClient): ...@@ -728,7 +726,6 @@ class AsyncLLMEngine(EngineClient):
"""Kick the engine to process the waiting requests. """Kick the engine to process the waiting requests.
Returns True if there are in-progress requests.""" Returns True if there are in-progress requests."""
new_requests, aborted_requests = ( new_requests, aborted_requests = (
self._request_tracker.get_new_and_aborted_requests()) self._request_tracker.get_new_and_aborted_requests())
...@@ -748,7 +745,6 @@ class AsyncLLMEngine(EngineClient): ...@@ -748,7 +745,6 @@ class AsyncLLMEngine(EngineClient):
await self._engine_abort(aborted_requests) await self._engine_abort(aborted_requests)
request_outputs = await self.engine.step_async(virtual_engine) request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams. # Put the outputs into the corresponding streams.
# If used as a callback, then already invoked inside # If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs # LLMEngine's _process_model_outputs
......
...@@ -10,6 +10,7 @@ from collections import deque ...@@ -10,6 +10,7 @@ from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
import traceback
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable, from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional) List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
...@@ -1265,7 +1266,10 @@ class LLMEngine: ...@@ -1265,7 +1266,10 @@ class LLMEngine:
seq = seq_group.seqs[0] seq = seq_group.seqs[0]
for token_id, seq_id in zip(sample_out_list, sample_out_ids): for token_id, seq_id in zip(sample_out_list, sample_out_ids):
if seq.seq_id == seq_id: if seq.seq_id == seq_id:
sample.output_token = token_id[0] if type(token_id) is list:
sample.output_token = token_id[0]
else:
sample.output_token = token_id
seq.fix_last_token_id(sample.output_token) seq.fix_last_token_id(sample.output_token)
break break
...@@ -1492,8 +1496,12 @@ class LLMEngine: ...@@ -1492,8 +1496,12 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs): >>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break >>> break
""" """
#traceback.print_stack()
if self.zero_overhead: if self.zero_overhead:
return self.zero_overhead_step() out = self.zero_overhead_step()
if out is None: #the first step need launch twice
out = self.zero_overhead_step()
return out
if self.parallel_config.pipeline_parallel_size > 1: if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -1388,8 +1388,6 @@ class LLM: ...@@ -1388,8 +1388,6 @@ class LLM:
total_out_toks = 0 total_out_toks = 0
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step() step_outputs = self.llm_engine.step()
if step_outputs is None:
continue
for output in step_outputs: for output in step_outputs:
if output.finished: if output.finished:
outputs.append(output) outputs.append(output)
......
...@@ -477,7 +477,8 @@ def _greedy_sample( ...@@ -477,7 +477,8 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
samples_lst = samples.tolist() if not d2d_data.zero_overhead:
samples_lst = samples.tolist()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
for seq_group in selected_seq_groups: for seq_group in selected_seq_groups:
...@@ -490,7 +491,10 @@ def _greedy_sample( ...@@ -490,7 +491,10 @@ def _greedy_sample(
assert num_parent_seqs == 1, ( assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.") "Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples_lst[sample_idx]] if d2d_data.zero_overhead:
next_token_ids = [0] #place holder token id
else:
next_token_ids = [samples_lst[sample_idx]]
results.append((next_token_ids, parent_ids)) results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
return results return results
...@@ -713,12 +717,12 @@ def get_pythonized_sample_results( ...@@ -713,12 +717,12 @@ def get_pythonized_sample_results(
sample_result_args.beam_search_logprobs, sample_result_args.beam_search_logprobs,
sample_result_args.sample_results_dict, sample_result_args.sample_results_dict,
) )
for sampling_type in SamplingType: for sampling_type in SamplingType:
if sampling_type not in sample_metadata: if sampling_type not in sample_metadata:
continue continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type] (seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
d2d_data.random_samples = greedy_samples
sample_results = _greedy_sample(seq_groups, greedy_samples) sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
d2d_data.random_samples = multinomial_samples[sampling_type]#记录random_samples的数据 d2d_data.random_samples = multinomial_samples[sampling_type]#记录random_samples的数据
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment