Commit a495fc3b authored by zhuwenwen's avatar zhuwenwen
Browse files

fix zero overhead to support chunk prefill

parent fe1c4016
......@@ -299,7 +299,10 @@ class ZeroOverheadEngine(LLMEngine):
last_sampler = self.last_record[1]
spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
if last_sampler.sampled_token_ids_tensor is not None:
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
else:
self.async_d2h = None
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_d2h = last_sampler.to('cpu', non_blocking=True)
self.async_event.record()
......@@ -367,6 +370,7 @@ class ZeroOverheadEngine(LLMEngine):
ctx.scheduler_outputs = scheduler_outputs
if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_event.synchronize()
if self.async_d2h is not None:
self._fix_last_step(
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
......
......@@ -99,13 +99,15 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
if spec_step == SpecStepKind.KIND_DEFAULT:
update_indices = []
select_indices = []
query_idx = 0
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
update_indices.append(query_idx)
break
if len(select_indices) > 0:
query_idx += model_input.query_lens[i]
if len(select_indices) > 0 and last_sampler.sampled_token_ids_tensor is not None:
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
......
from importlib.util import find_spec
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import torch
from vllm import envs
from vllm.model_executor.layers.sampler import MultinomialSamplesType, SampleMetadataType, \
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.sampler import MaybeDeferredSampleResultType, MultinomialSamplesType, SampleMetadataType, \
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, _build_sampler_output, \
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, \
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs, _multinomial
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
from vllm.sampling_params import SamplingType
from vllm.sequence import VLLM_INVALID_TOKEN_ID
from vllm.sequence import VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, PromptLogprobs, SampleLogprobs, SequenceOutput
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
......@@ -275,10 +276,8 @@ def _sample_with_torch(
t: []
for t in SamplingType
}
last_sampler.seq_ids = []
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
last_sampler.seq_ids.append(seq_group.seq_ids[0])
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
......@@ -430,3 +429,72 @@ def get_pythonized_sample_results(
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
def _build_sampler_output(
maybe_deferred_sample_results: MaybeDeferredSampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
skip_sampler_cpu_output: bool = False,
logits: Optional[torch.Tensor] = None
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
Args:
on_device_tensors: Tuple containing on-device tensors with the
probabilities used in sampling and the sampled token ids. This
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output: List[CompletionSequenceGroupOutput] = []
last_sampler.seq_ids = []
if skip_sampler_cpu_output:
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
deferred_sample_results_args = maybe_deferred_sample_results
else:
assert prompt_logprobs is not None
assert sample_logprobs is not None
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
assert len(sampling_metadata.seq_groups) \
== len(maybe_deferred_sample_results) \
== len(prompt_logprobs) \
== len(sample_logprobs)
deferred_sample_results_args = None
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
maybe_deferred_sample_results,
prompt_logprobs, sample_logprobs):
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(
parent_ids, next_token_ids, group_sample_logprobs):
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
logprobs))
sampler_output.append(
CompletionSequenceGroupOutput(seq_outputs,
group_prompt_logprobs))
if len(seq_outputs) > 0:
last_sampler.seq_ids.append(seq_outputs[0].parent_seq_id)
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
(sampled_token_probs, logprobs_tensor,
sampled_token_ids) = on_device_tensors
else:
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
None)
return SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args,
logits=logits)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment