Commit d4b6b8cc authored by lizhigong's avatar lizhigong
Browse files

support serving in use zero overhead

parent 024e595d
......@@ -280,8 +280,6 @@ class _AsyncLLMEngine(LLMEngine):
"""
# these are cached outputs from previous iterations. None if on first
# iteration
if self.zero_overhead:
return self.zero_overhead_step()
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
......@@ -728,7 +726,6 @@ class AsyncLLMEngine(EngineClient):
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests, aborted_requests = (
self._request_tracker.get_new_and_aborted_requests())
......@@ -748,7 +745,6 @@ class AsyncLLMEngine(EngineClient):
await self._engine_abort(aborted_requests)
request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams.
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
......
......@@ -10,6 +10,7 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
import traceback
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
......@@ -1265,7 +1266,10 @@ class LLMEngine:
seq = seq_group.seqs[0]
for token_id, seq_id in zip(sample_out_list, sample_out_ids):
if seq.seq_id == seq_id:
if type(token_id) is list:
sample.output_token = token_id[0]
else:
sample.output_token = token_id
seq.fix_last_token_id(sample.output_token)
break
......@@ -1492,8 +1496,12 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
"""
#traceback.print_stack()
if self.zero_overhead:
return self.zero_overhead_step()
out = self.zero_overhead_step()
if out is None: #the first step need launch twice
out = self.zero_overhead_step()
return out
if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
......
......@@ -1388,8 +1388,6 @@ class LLM:
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
if step_outputs is None:
continue
for output in step_outputs:
if output.finished:
outputs.append(output)
......
......@@ -477,6 +477,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
if not d2d_data.zero_overhead:
samples_lst = samples.tolist()
sample_idx = 0
results: SampleResultType = []
......@@ -490,6 +491,9 @@ def _greedy_sample(
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
next_token_ids = [0] #place holder token id
else:
next_token_ids = [samples_lst[sample_idx]]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
......@@ -713,12 +717,12 @@ def get_pythonized_sample_results(
sample_result_args.beam_search_logprobs,
sample_result_args.sample_results_dict,
)
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
d2d_data.random_samples = greedy_samples
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
d2d_data.random_samples = multinomial_samples[sampling_type]#记录random_samples的数据
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment