Commit 411d255e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev_mtp_sampler' into v0.9.2-dev

parents 18b4e6f3 33e33aa7
......@@ -8,6 +8,7 @@ from typing import Optional, cast
import numpy as np
import torch
from vllm import envs
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
......@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessorManager,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
init_builtin_logitsprocs)
from vllm.v1.sample.metadata import SamplingMetadata
......@@ -192,6 +197,10 @@ class InputBatch:
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# Track whether sampling metadata is currently expanded to
# per-token shape (spec decode reject path).
self._sampling_metadata_is_expanded = False
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
......@@ -593,7 +602,7 @@ class InputBatch:
del self._req_ids[self.num_reqs:]
del self.req_output_token_ids[self.num_reqs:]
def refresh_metadata(self):
def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
"""Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states
......@@ -602,8 +611,17 @@ class InputBatch:
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
self.sampling_metadata = self._make_sampling_metadata()
needs_rebuild = (batch_update
or repeat_counts is not None
or self._sampling_metadata_is_expanded)
if needs_rebuild:
if repeat_counts is None:
self.sampling_metadata = self._make_sampling_metadata()
else:
self.sampling_metadata = self._make_sampling_metadata_expanded(
repeat_counts)
self._sampling_metadata_is_expanded = repeat_counts is not None
# Expanded metadata is built on demand; do not cache a copy here.
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
......@@ -666,6 +684,105 @@ class InputBatch:
logitsprocs=self.logitsprocs,
)
def _make_sampling_metadata_expanded(
    self, repeat_counts: torch.Tensor
) -> SamplingMetadata:
    """Build sampling metadata with per-request entries repeated per row.

    Row ``i`` of every per-request sampling tensor (restricted to the
    first ``self.num_reqs`` rows) is repeated ``repeat_counts[i]`` times
    along dim 0 before being copied to ``self.device``. The Python-side
    containers (output token ids, bad words, generators) are expanded the
    same way and re-keyed by expanded row index; repeated rows share the
    same underlying list/generator object rather than a copy.

    Args:
        repeat_counts: One integer repeat count per request. It is used
            with ``repeat_interleave`` on the CPU-side tensors and read
            via ``tolist()``, so it is expected to live on the CPU.
            Requests whose count is ``<= 0`` contribute no rows to the
            Python-side containers.

    Returns:
        A ``SamplingMetadata`` whose per-row fields follow the expanded
        row layout. ``logitsprocs`` is passed through unchanged.
    """
    num_reqs = self.num_reqs
    all_greedy = self.all_greedy

    def _expand_cpu_to_gpu(
        t: Optional[torch.Tensor],
        *,
        dtype: Optional[torch.dtype] = None,
    ) -> Optional[torch.Tensor]:
        # Repeat the first num_reqs rows per request, then move the result
        # to the target device (optionally converting dtype on the way).
        if t is None:
            return None
        expanded = t[:num_reqs].repeat_interleave(repeat_counts, dim=0)
        return expanded.to(device=self.device,
                           dtype=dtype,
                           non_blocking=True)

    needs_prompt_token_ids = (not self.no_penalties or
                              (self.num_reqs > 0
                               and self.logits_processing_needs_token_ids))
    if needs_prompt_token_ids:
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = self._make_prompt_token_ids_tensor(repeat_counts)
    else:
        prompt_token_ids = None

    allowed_token_ids_mask: Optional[torch.Tensor] = None
    if not self.no_allowed_token_ids:
        assert self.allowed_token_ids_mask is not None
        # Start from the CPU copy so it can be expanded like the others.
        allowed_token_ids_mask = self.allowed_token_ids_mask_cpu_tensor

    top_p_cpu = None if self.no_top_p else self.top_p_cpu_tensor
    top_k_cpu = None if self.no_top_k else self.top_k_cpu_tensor

    # Expand the Python-side per-request containers; dict keys become
    # expanded row indices instead of request indices.
    repeat_list = repeat_counts.tolist()
    expanded_output_token_ids: list[list[int]] = []
    expanded_bad_words_token_ids: dict[int, list[list[int]]] = {}
    expanded_generators: dict[int, torch.Generator] = {}
    row_idx = 0
    for req_idx in range(num_reqs):
        repeat = int(repeat_list[req_idx])
        if repeat <= 0:
            continue
        output_tokens = self.req_output_token_ids[req_idx]
        assert output_tokens is not None
        bad_words = self.bad_words_token_ids.get(req_idx)
        generator = self.generators.get(req_idx)

        # Every repeated row references the same objects as its request.
        expanded_output_token_ids.extend([output_tokens] * repeat)
        for row in range(row_idx, row_idx + repeat):
            if bad_words is not None:
                expanded_bad_words_token_ids[row] = bad_words
            if generator is not None:
                expanded_generators[row] = generator
        row_idx += repeat

    return SamplingMetadata(
        temperature=_expand_cpu_to_gpu(
            None if all_greedy else self.temperature_cpu_tensor),
        all_greedy=all_greedy,
        all_random=self.all_random,
        top_p=_expand_cpu_to_gpu(top_p_cpu),
        top_k=_expand_cpu_to_gpu(top_k_cpu, dtype=torch.int32),
        generators=expanded_generators,
        max_num_logprobs=self.max_num_logprobs,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=(
            None if self.no_penalties else _expand_cpu_to_gpu(
                self.frequency_penalties_cpu_tensor)),
        presence_penalties=(
            None if self.no_penalties else _expand_cpu_to_gpu(
                self.presence_penalties_cpu_tensor)),
        repetition_penalties=(
            None if self.no_penalties else _expand_cpu_to_gpu(
                self.repetition_penalties_cpu_tensor)),
        output_token_ids=expanded_output_token_ids,
        no_penalties=self.no_penalties,
        allowed_token_ids_mask=_expand_cpu_to_gpu(
            allowed_token_ids_mask, dtype=torch.bool),
        bad_words_token_ids=expanded_bad_words_token_ids,
        logitsprocs=self.logitsprocs,
    )
@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
......@@ -685,7 +802,9 @@ class InputBatch:
pooling_params=pooling_params,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
def _make_prompt_token_ids_tensor(
self, repeat_counts_cpu: Optional[torch.Tensor] = None
) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
......@@ -700,6 +819,9 @@ class InputBatch:
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
if repeat_counts_cpu is not None:
prompt_token_ids_cpu_tensor = prompt_token_ids_cpu_tensor \
.repeat_interleave(repeat_counts_cpu, dim=0)
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
......
......@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.input_batch.condense()
# Allow attention backend to reorder the batch, potentially
self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
# Refresh batch metadata with any pending updates. If we are in spec
# decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts: Optional[torch.Tensor] = None
if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens:
num_reqs = self.input_batch.num_reqs
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None:
num_draft_tokens[req_idx] = len(draft_token_ids)
repeat_counts = torch.from_numpy(num_draft_tokens).add_(1)
self.input_batch.refresh_metadata(repeat_counts)
def _get_cumsum_and_arange(
self,
......@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
sampler_output.sampled_token_ids = output_token_ids
else:
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment