Commit 411d255e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev_mtp_sampler' into v0.9.2-dev

parents 18b4e6f3 33e33aa7
...@@ -8,6 +8,7 @@ from typing import Optional, cast ...@@ -8,6 +8,7 @@ from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from vllm import envs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values ...@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessorManager,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality, MoveDirectionality,
init_builtin_logitsprocs) init_builtin_logitsprocs)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
...@@ -192,6 +197,10 @@ class InputBatch: ...@@ -192,6 +197,10 @@ class InputBatch:
self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set() self.repetition_penalties_reqs: set[str] = set()
# Track whether sampling metadata is currently expanded to
# per-token shape (spec decode reject path).
self._sampling_metadata_is_expanded = False
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ), self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32) dtype=np.int32)
...@@ -593,7 +602,7 @@ class InputBatch: ...@@ -593,7 +602,7 @@ class InputBatch:
del self._req_ids[self.num_reqs:] del self._req_ids[self.num_reqs:]
del self.req_output_token_ids[self.num_reqs:] del self.req_output_token_ids[self.num_reqs:]
def refresh_metadata(self): def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
"""Apply batch updates, reset input batch at end of step """Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states * Apply batch add/remove/permute to logits procs' states
...@@ -602,8 +611,17 @@ class InputBatch: ...@@ -602,8 +611,17 @@ class InputBatch:
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all: for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update) logit_proc.update_state(batch_update)
if batch_update: needs_rebuild = (batch_update
or repeat_counts is not None
or self._sampling_metadata_is_expanded)
if needs_rebuild:
if repeat_counts is None:
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata()
else:
self.sampling_metadata = self._make_sampling_metadata_expanded(
repeat_counts)
self._sampling_metadata_is_expanded = repeat_counts is not None
# Expanded metadata is built on demand; do not cache a copy here.
def _make_sampling_metadata(self) -> SamplingMetadata: def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs num_reqs = self.num_reqs
...@@ -666,6 +684,105 @@ class InputBatch: ...@@ -666,6 +684,105 @@ class InputBatch:
logitsprocs=self.logitsprocs, logitsprocs=self.logitsprocs,
) )
def _make_sampling_metadata_expanded(
self, repeat_counts: torch.Tensor
) -> SamplingMetadata:
num_reqs = self.num_reqs
repeat_counts_cpu = repeat_counts
all_greedy = self.all_greedy
all_random = self.all_random
# For reject-sampling optimization, force greedy sampling to keep
# rejection sampler assumptions (per-request shapes) intact.
def _expand_cpu_to_gpu(
t: Optional[torch.Tensor],
*,
dtype: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]:
if t is None:
return None
base = t[:num_reqs]
if repeat_counts_cpu is not None:
base = base.repeat_interleave(repeat_counts_cpu, dim=0)
return base.to(device=self.device,
dtype=dtype if dtype is not None else None,
non_blocking=True)
needs_prompt_token_ids = (not self.no_penalties or
(self.num_reqs > 0
and self.logits_processing_needs_token_ids))
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = self._make_prompt_token_ids_tensor(
repeat_counts_cpu)
else:
prompt_token_ids = None
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
allowed_token_ids_mask = self.allowed_token_ids_mask_cpu_tensor
# Expand per-request metadata to per-token shape when repeat_counts
# is provided (spec decode reject-sampling path).
top_p_cpu = None if self.no_top_p else self.top_p_cpu_tensor
top_k_cpu = None if self.no_top_k else self.top_k_cpu_tensor
repeat_list = repeat_counts_cpu.tolist()
row_offsets: list[int] = []
total_rows = 0
for repeat in repeat_list:
row_offsets.append(total_rows)
total_rows += int(repeat)
expanded_output_token_ids: list[list[int]] = []
expanded_bad_words_token_ids: dict[int, list[list[int]]] = {}
expanded_generators: dict[int, torch.Generator] = {}
row_idx = 0
for req_idx in range(num_reqs):
repeat = int(repeat_list[req_idx])
if repeat <= 0:
continue
output_tokens = self.req_output_token_ids[req_idx]
assert output_tokens is not None
bad_words = self.bad_words_token_ids.get(req_idx)
generator = self.generators.get(req_idx)
for _ in range(repeat):
expanded_output_token_ids.append(output_tokens)
if bad_words is not None:
expanded_bad_words_token_ids[row_idx] = bad_words
if generator is not None:
expanded_generators[row_idx] = generator
row_idx += 1
return SamplingMetadata(
temperature=_expand_cpu_to_gpu(
None if all_greedy else self.temperature_cpu_tensor),
all_greedy=all_greedy,
all_random=all_random,
top_p=_expand_cpu_to_gpu(top_p_cpu),
top_k=_expand_cpu_to_gpu(top_k_cpu, dtype=torch.int32),
generators=expanded_generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=(
None if self.no_penalties else _expand_cpu_to_gpu(
self.frequency_penalties_cpu_tensor)),
presence_penalties=(
None if self.no_penalties else _expand_cpu_to_gpu(
self.presence_penalties_cpu_tensor)),
repetition_penalties=(
None if self.no_penalties else _expand_cpu_to_gpu(
self.repetition_penalties_cpu_tensor)),
output_token_ids=expanded_output_token_ids,
no_penalties=self.no_penalties,
allowed_token_ids_mask=_expand_cpu_to_gpu(
allowed_token_ids_mask, dtype=torch.bool),
bad_words_token_ids=expanded_bad_words_token_ids,
logitsprocs=self.logitsprocs,
)
@property @property
def pooling_metadata(self) -> PoolingMetadata: def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0: if len(self.pooling_params) == 0:
...@@ -685,7 +802,9 @@ class InputBatch: ...@@ -685,7 +802,9 @@ class InputBatch:
pooling_params=pooling_params, pooling_params=pooling_params,
) )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def _make_prompt_token_ids_tensor(
self, repeat_counts_cpu: Optional[torch.Tensor] = None
) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty( prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len), (self.num_reqs, max_prompt_len),
...@@ -700,6 +819,9 @@ class InputBatch: ...@@ -700,6 +819,9 @@ class InputBatch:
# token_id of this value. # token_id of this value.
for i in range(self.num_reqs): for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
if repeat_counts_cpu is not None:
prompt_token_ids_cpu_tensor = prompt_token_ids_cpu_tensor \
.repeat_interleave(repeat_counts_cpu, dim=0)
return prompt_token_ids_cpu_tensor.to(device=self.device, return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True) non_blocking=True)
......
...@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.input_batch.condense() self.input_batch.condense()
# Allow attention backend to reorder the batch, potentially # Allow attention backend to reorder the batch, potentially
self._may_reorder_batch(scheduler_output) self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates. # Refresh batch metadata with any pending updates. If we are in spec
self.input_batch.refresh_metadata() # decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts: Optional[torch.Tensor] = None
if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens:
num_reqs = self.input_batch.num_reqs
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None:
num_draft_tokens[req_idx] = len(draft_token_ids)
repeat_counts = torch.from_numpy(num_draft_tokens).add_(1)
self.input_batch.refresh_metadata(repeat_counts)
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
self, self,
...@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else: else:
sampling_metadata.all_greedy = True # sampling_metadata.all_greedy = True
sampling_metadata.all_random = False # sampling_metadata.all_random = False
sampler_output = self.sampler( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment