Commit a495fc3b authored by zhuwenwen's avatar zhuwenwen
Browse files

fix zero overhead to support chunk prefill

parent fe1c4016
...@@ -299,7 +299,10 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -299,7 +299,10 @@ class ZeroOverheadEngine(LLMEngine):
last_sampler = self.last_record[1] last_sampler = self.last_record[1]
spec_step = get_spec_step() spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT: if spec_step == SpecStepKind.KIND_DEFAULT:
if last_sampler.sampled_token_ids_tensor is not None:
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True) self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
else:
self.async_d2h = None
elif spec_step == SpecStepKind.SCORE_DECODE: elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_d2h = last_sampler.to('cpu', non_blocking=True) self.async_d2h = last_sampler.to('cpu', non_blocking=True)
self.async_event.record() self.async_event.record()
...@@ -367,6 +370,7 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -367,6 +370,7 @@ class ZeroOverheadEngine(LLMEngine):
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
if spec_step == SpecStepKind.KIND_DEFAULT: if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_event.synchronize() self.async_event.synchronize()
if self.async_d2h is not None:
self._fix_last_step( self._fix_last_step(
outputs, last_sampler, seq_group_metadata_list, outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
......
...@@ -99,13 +99,15 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder): ...@@ -99,13 +99,15 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
if spec_step == SpecStepKind.KIND_DEFAULT: if spec_step == SpecStepKind.KIND_DEFAULT:
update_indices = [] update_indices = []
select_indices = [] select_indices = []
query_idx = 0
for i, seq_id in enumerate(self.req_ids): for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids): for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_: if seq_id == seq_id_:
select_indices.append(j) select_indices.append(j)
update_indices.append(i) update_indices.append(query_idx)
break break
if len(select_indices) > 0: query_idx += model_input.query_lens[i]
if len(select_indices) > 0 and last_sampler.sampled_token_ids_tensor is not None:
select_indices = async_tensor_h2d(select_indices, torch.long, select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
......
from importlib.util import find_spec from importlib.util import find_spec
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm import envs from vllm import envs
from vllm.model_executor.layers.sampler import MultinomialSamplesType, SampleMetadataType, \ from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.sampler import MaybeDeferredSampleResultType, MultinomialSamplesType, SampleMetadataType, \
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \ SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, _build_sampler_output, \ SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, \
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs, _multinomial _modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs, _multinomial
from vllm.model_executor.layers.utils import apply_penalties from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import VLLM_INVALID_TOKEN_ID from vllm.sequence import VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, PromptLogprobs, SampleLogprobs, SequenceOutput
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
# yapf: disable # yapf: disable
...@@ -275,10 +276,8 @@ def _sample_with_torch( ...@@ -275,10 +276,8 @@ def _sample_with_torch(
t: [] t: []
for t in SamplingType for t in SamplingType
} }
last_sampler.seq_ids = []
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
last_sampler.seq_ids.append(seq_group.seq_ids[0])
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
...@@ -430,3 +429,72 @@ def get_pythonized_sample_results( ...@@ -430,3 +429,72 @@ def get_pythonized_sample_results(
sample_results_dict.get(i, ([], [])) sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups)) for i in range(len(sampling_metadata.seq_groups))
] ]
def _build_sampler_output(
maybe_deferred_sample_results: MaybeDeferredSampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
skip_sampler_cpu_output: bool = False,
logits: Optional[torch.Tensor] = None
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
Args:
on_device_tensors: Tuple containing on-device tensors with the
probabilities used in sampling and the sampled token ids. This
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output: List[CompletionSequenceGroupOutput] = []
last_sampler.seq_ids = []
if skip_sampler_cpu_output:
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
deferred_sample_results_args = maybe_deferred_sample_results
else:
assert prompt_logprobs is not None
assert sample_logprobs is not None
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
assert len(sampling_metadata.seq_groups) \
== len(maybe_deferred_sample_results) \
== len(prompt_logprobs) \
== len(sample_logprobs)
deferred_sample_results_args = None
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
maybe_deferred_sample_results,
prompt_logprobs, sample_logprobs):
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(
parent_ids, next_token_ids, group_sample_logprobs):
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
logprobs))
sampler_output.append(
CompletionSequenceGroupOutput(seq_outputs,
group_prompt_logprobs))
if len(seq_outputs) > 0:
last_sampler.seq_ids.append(seq_outputs[0].parent_seq_id)
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
(sampled_token_probs, logprobs_tensor,
sampled_token_ids) = on_device_tensors
else:
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
None)
return SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args,
logits=logits)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment