Commit a0d556fe authored by laibao's avatar laibao
Browse files

[bugfix] 优化 reject-sampling 的 InputBatch 元数据处理

- 在 InputBatch.refresh_metadata 中为展开后的采样元数据引入 repeat_counts 参数,记录每个请求的重复次数
- 完善元数据刷新逻辑以适配 reject-sampling 优化路径
- 更新 GPUModelRunnerBase,在 batch 处理阶段正确消费新的采样元数据与重复计数
parent e9532d9e
...@@ -8,6 +8,7 @@ from typing import Optional, cast ...@@ -8,6 +8,7 @@ from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from vllm import envs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values ...@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessorManager,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality, MoveDirectionality,
init_builtin_logitsprocs) init_builtin_logitsprocs)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
...@@ -192,6 +197,10 @@ class InputBatch: ...@@ -192,6 +197,10 @@ class InputBatch:
self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set() self.repetition_penalties_reqs: set[str] = set()
# Track whether sampling metadata is currently expanded to
# per-token shape (spec decode reject path).
self._sampling_metadata_is_expanded = False
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ), self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32) dtype=np.int32)
...@@ -593,7 +602,7 @@ class InputBatch: ...@@ -593,7 +602,7 @@ class InputBatch:
del self._req_ids[self.num_reqs:] del self._req_ids[self.num_reqs:]
del self.req_output_token_ids[self.num_reqs:] del self.req_output_token_ids[self.num_reqs:]
def refresh_metadata(self): def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
"""Apply batch updates, reset input batch at end of step """Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states * Apply batch add/remove/permute to logits procs' states
...@@ -602,8 +611,17 @@ class InputBatch: ...@@ -602,8 +611,17 @@ class InputBatch:
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all: for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update) logit_proc.update_state(batch_update)
if batch_update: needs_rebuild = (batch_update
self.sampling_metadata = self._make_sampling_metadata() or repeat_counts is not None
or self._sampling_metadata_is_expanded)
if needs_rebuild:
if repeat_counts is None:
self.sampling_metadata = self._make_sampling_metadata()
else:
self.sampling_metadata = self._make_sampling_metadata_expanded(
repeat_counts)
self._sampling_metadata_is_expanded = repeat_counts is not None
# Expanded metadata is built on demand; do not cache a copy here.
def _make_sampling_metadata(self) -> SamplingMetadata: def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs num_reqs = self.num_reqs
...@@ -666,6 +684,107 @@ class InputBatch: ...@@ -666,6 +684,107 @@ class InputBatch:
logitsprocs=self.logitsprocs, logitsprocs=self.logitsprocs,
) )
def _make_sampling_metadata_expanded(
    self, repeat_counts: torch.Tensor
) -> SamplingMetadata:
    """Build sampling metadata expanded from per-request to per-row shape.

    Request ``i`` (for ``i`` in ``[0, self.num_reqs)``) contributes
    ``repeat_counts[i]`` consecutive rows to every expanded field. This is
    used on the spec decode reject-sampling path, where sampling runs on a
    per-token rather than per-request batch.

    Args:
        repeat_counts: 1-D integer tensor holding one row count per
            request in the batch. Requests whose count is <= 0 contribute
            no entries to the list/dict based fields.

    Returns:
        A new ``SamplingMetadata``. ``all_greedy`` and ``all_random`` are
        passed through unchanged from the batch; nothing is forced to
        greedy here.

    Raises:
        ValueError: If ``repeat_counts`` does not have exactly
            ``self.num_reqs`` entries.
    """
    num_reqs = self.num_reqs
    # The counts are consumed on the host (Python loops below and
    # repeat_interleave over the *_cpu_tensor fields). `.cpu()` returns the
    # tensor itself when it already lives in CPU memory.
    repeat_counts_cpu = repeat_counts.cpu()
    repeat_list = repeat_counts_cpu.tolist()
    if len(repeat_list) != num_reqs:
        raise ValueError(
            f"repeat_counts has {len(repeat_list)} entries, expected one "
            f"per request ({num_reqs})")

    all_greedy = self.all_greedy
    all_random = self.all_random

    def _expand_cpu_to_gpu(
        t: Optional[torch.Tensor],
        *,
        dtype: Optional[torch.dtype] = None,
    ) -> Optional[torch.Tensor]:
        """Repeat the first ``num_reqs`` rows of ``t`` and copy to device."""
        if t is None:
            return None
        expanded = t[:num_reqs].repeat_interleave(repeat_counts_cpu, dim=0)
        return expanded.to(device=self.device,
                           dtype=dtype,
                           non_blocking=True)

    needs_prompt_token_ids = (not self.no_penalties or
                              (num_reqs > 0
                               and self.logits_processing_needs_token_ids))
    if needs_prompt_token_ids:
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = self._make_prompt_token_ids_tensor(
            repeat_counts_cpu)
    else:
        prompt_token_ids = None

    allowed_token_ids_mask: Optional[torch.Tensor] = None
    if not self.no_allowed_token_ids:
        assert self.allowed_token_ids_mask is not None
        # Expand from the CPU-side mask; it is moved to the device by
        # _expand_cpu_to_gpu when the metadata object is assembled.
        allowed_token_ids_mask = self.allowed_token_ids_mask_cpu_tensor

    top_p_cpu = None if self.no_top_p else self.top_p_cpu_tensor
    top_k_cpu = None if self.no_top_k else self.top_k_cpu_tensor

    # row_offsets[i] is the index of the first expanded row of request i.
    row_offsets: list[int] = []
    total_rows = 0
    for repeat in repeat_list:
        row_offsets.append(total_rows)
        total_rows += int(repeat)

    expanded_logitsprocs = self._expand_logitsprocs(
        repeat_list, row_offsets, total_rows)

    # Per-row views of the Python-side (non-tensor) request state. Rows of
    # the same request share the same underlying list / generator objects.
    expanded_output_token_ids: list[list[int]] = []
    expanded_bad_words_token_ids: dict[int, list[list[int]]] = {}
    expanded_generators: dict[int, torch.Generator] = {}
    row_idx = 0
    for req_idx in range(num_reqs):
        repeat = int(repeat_list[req_idx])
        if repeat <= 0:
            continue
        output_tokens = self.req_output_token_ids[req_idx]
        assert output_tokens is not None
        bad_words = self.bad_words_token_ids.get(req_idx)
        generator = self.generators.get(req_idx)
        for _ in range(repeat):
            expanded_output_token_ids.append(output_tokens)
            if bad_words is not None:
                expanded_bad_words_token_ids[row_idx] = bad_words
            if generator is not None:
                expanded_generators[row_idx] = generator
            row_idx += 1

    return SamplingMetadata(
        temperature=_expand_cpu_to_gpu(
            None if all_greedy else self.temperature_cpu_tensor),
        all_greedy=all_greedy,
        all_random=all_random,
        top_p=_expand_cpu_to_gpu(top_p_cpu),
        top_k=_expand_cpu_to_gpu(top_k_cpu, dtype=torch.int32),
        generators=expanded_generators,
        max_num_logprobs=self.max_num_logprobs,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=(
            None if self.no_penalties else _expand_cpu_to_gpu(
                self.frequency_penalties_cpu_tensor)),
        presence_penalties=(
            None if self.no_penalties else _expand_cpu_to_gpu(
                self.presence_penalties_cpu_tensor)),
        repetition_penalties=(
            None if self.no_penalties else _expand_cpu_to_gpu(
                self.repetition_penalties_cpu_tensor)),
        output_token_ids=expanded_output_token_ids,
        no_penalties=self.no_penalties,
        allowed_token_ids_mask=_expand_cpu_to_gpu(
            allowed_token_ids_mask, dtype=torch.bool),
        bad_words_token_ids=expanded_bad_words_token_ids,
        logitsprocs=expanded_logitsprocs,
    )
@property @property
def pooling_metadata(self) -> PoolingMetadata: def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0: if len(self.pooling_params) == 0:
...@@ -685,7 +804,9 @@ class InputBatch: ...@@ -685,7 +804,9 @@ class InputBatch:
pooling_params=pooling_params, pooling_params=pooling_params,
) )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def _make_prompt_token_ids_tensor(
self, repeat_counts_cpu: Optional[torch.Tensor] = None
) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty( prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len), (self.num_reqs, max_prompt_len),
...@@ -700,9 +821,116 @@ class InputBatch: ...@@ -700,9 +821,116 @@ class InputBatch:
# token_id of this value. # token_id of this value.
for i in range(self.num_reqs): for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
if repeat_counts_cpu is not None:
prompt_token_ids_cpu_tensor = prompt_token_ids_cpu_tensor \
.repeat_interleave(repeat_counts_cpu, dim=0)
return prompt_token_ids_cpu_tensor.to(device=self.device, return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True) non_blocking=True)
def _expand_logitsprocs(
    self, repeat_list: list[int], row_offsets: list[int], total_rows: int
) -> LogitsProcessorManager:
    """Return logits processors re-indexed for a repeat-expanded batch.

    ``repeat_list[i]`` is the number of rows request ``i`` occupies in the
    expanded batch, ``row_offsets[i]`` is the index of its first row and
    ``total_rows`` is the overall row count. Min-p, logit-bias and
    min-tokens processors are rebuilt with per-row state; every other
    processor is passed through as-is.
    """

    def _rows_of(req_idx: int) -> range:
        """Expanded row indices owned by a request (empty if count <= 0)."""
        count = repeat_list[req_idx]
        if count <= 0:
            return range(0)
        first = row_offsets[req_idx]
        return range(first, first + count)

    def _expand_min_p(proc: MinPLogitsProcessor) -> MinPLogitsProcessor:
        clone = MinPLogitsProcessor(max_num_reqs=total_rows,
                                    pin_memory=self.pin_memory,
                                    device=self.device)
        if total_rows == 0:
            # Nothing to expand: leave an empty (0, 1) min_p view.
            clone.min_p = clone.min_p_device[:0]
            clone.min_p.unsqueeze_(1)
            clone.min_p_count = 0
            return clone

        per_request = torch.from_numpy(proc.min_p_cpu[:self.num_reqs])
        per_row = per_request.repeat_interleave(
            torch.tensor(repeat_list, dtype=torch.int64))
        # Stage on the CPU tensor, then copy into the device tensor.
        staging = clone.min_p_cpu_tensor[:total_rows]
        staging.copy_(per_row)
        clone.min_p = clone.min_p_device[:total_rows]
        clone.min_p.copy_(staging, non_blocking=True)
        clone.min_p.unsqueeze_(1)
        clone.min_p_count = int((per_row != 0).sum().item())
        return clone

    def _expand_logit_bias(
            proc: LogitBiasLogitsProcessor) -> LogitBiasLogitsProcessor:
        clone = LogitBiasLogitsProcessor(pin_memory=self.pin_memory,
                                         device=self.device)
        # The biases dict is shared with the source processor (not copied).
        clone.biases = proc.biases
        if total_rows == 0 or not proc.biases:
            return clone

        rows: list[int] = []
        toks: list[int] = []
        vals: list[float] = []
        for req_idx, token_biases in proc.biases.items():
            for row in _rows_of(req_idx):
                for tok_id, bias in token_biases.items():
                    rows.append(row)
                    toks.append(tok_id)
                    vals.append(bias)
        if vals:
            clone.bias_tensor = clone._device_tensor(vals, torch.float32)
            clone.logits_slice = (
                clone._device_tensor(rows, torch.int32),
                clone._device_tensor(toks, torch.int32),
            )
        return clone

    def _expand_min_tokens(
            proc: MinTokensLogitsProcessor) -> MinTokensLogitsProcessor:
        clone = MinTokensLogitsProcessor(pin_memory=self.pin_memory,
                                         device=self.device)
        # The min_toks dict is shared with the source processor (not copied).
        clone.min_toks = proc.min_toks
        if total_rows == 0 or not proc.min_toks:
            return clone

        rows: list[int] = []
        toks: list[int] = []
        for req_idx, (_, _, stop_tok_ids) in proc.min_toks.items():
            owned_rows = _rows_of(req_idx)
            if not owned_rows:
                continue
            stop_ids = list(stop_tok_ids)
            for row in owned_rows:
                rows.extend([row] * len(stop_ids))
                toks.extend(stop_ids)
        if toks:
            clone.logits_slice = (
                clone._device_tensor(rows, torch.int32),
                clone._device_tensor(toks, torch.int32),
            )
        return clone

    def _expand_non_argmax(proc):
        """Rebuild known non-argmax-invariant processors, else pass through."""
        if isinstance(proc, LogitBiasLogitsProcessor):
            return _expand_logit_bias(proc)
        if isinstance(proc, MinTokensLogitsProcessor):
            return _expand_min_tokens(proc)
        return proc

    return LogitsProcessorManager(
        argmax_invariant=[
            _expand_min_p(p) if isinstance(p, MinPLogitsProcessor) else p
            for p in self.logitsprocs.argmax_invariant
        ],
        non_argmax_invariant=[
            _expand_non_argmax(p)
            for p in self.logitsprocs.non_argmax_invariant
        ],
    )
def make_lora_inputs( def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray self, num_scheduled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
......
...@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.input_batch.condense() self.input_batch.condense()
# Allow attention backend to reorder the batch, potentially # Allow attention backend to reorder the batch, potentially
self._may_reorder_batch(scheduler_output) self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates. # Refresh batch metadata with any pending updates. If we are in spec
self.input_batch.refresh_metadata() # decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts: Optional[torch.Tensor] = None
if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens:
num_reqs = self.input_batch.num_reqs
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None:
num_draft_tokens[req_idx] = len(draft_token_ids)
repeat_counts = torch.from_numpy(num_draft_tokens).add_(1)
self.input_batch.refresh_metadata(repeat_counts)
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
self, self,
...@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else: else:
sampling_metadata.all_greedy = True # sampling_metadata.all_greedy = True
sampling_metadata.all_random = False # sampling_metadata.all_random = False
sampler_output = self.sampler( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment