Commit a0d556fe authored by laibao's avatar laibao
Browse files

[bugfix] 优化 reject-sampling 的 InputBatch 元数据处理

- 在 InputBatch.refresh_metadata 中为展开后的采样元数据引入 repeat_count 记录重复次数
- 完善元数据刷新逻辑以适配 reject-sampling 优化路径
- 更新 GPUModelRunnerBase,在 batch 处理阶段正确消费新的采样元数据与重复计数
parent e9532d9e
......@@ -8,6 +8,7 @@ from typing import Optional, cast
import numpy as np
import torch
from vllm import envs
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
......@@ -16,6 +17,10 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessorManager,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
init_builtin_logitsprocs)
from vllm.v1.sample.metadata import SamplingMetadata
......@@ -192,6 +197,10 @@ class InputBatch:
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# Track whether sampling metadata is currently expanded to
# per-token shape (spec decode reject path).
self._sampling_metadata_is_expanded = False
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
......@@ -593,7 +602,7 @@ class InputBatch:
del self._req_ids[self.num_reqs:]
del self.req_output_token_ids[self.num_reqs:]
def refresh_metadata(self):
def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
"""Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states
......@@ -602,8 +611,17 @@ class InputBatch:
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
needs_rebuild = (batch_update
or repeat_counts is not None
or self._sampling_metadata_is_expanded)
if needs_rebuild:
if repeat_counts is None:
self.sampling_metadata = self._make_sampling_metadata()
else:
self.sampling_metadata = self._make_sampling_metadata_expanded(
repeat_counts)
self._sampling_metadata_is_expanded = repeat_counts is not None
# Expanded metadata is built on demand; do not cache a copy here.
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
......@@ -666,6 +684,107 @@ class InputBatch:
logitsprocs=self.logitsprocs,
)
def _make_sampling_metadata_expanded(
self, repeat_counts: torch.Tensor
) -> SamplingMetadata:
num_reqs = self.num_reqs
repeat_counts_cpu = repeat_counts
all_greedy = self.all_greedy
all_random = self.all_random
# For reject-sampling optimization, force greedy sampling to keep
# rejection sampler assumptions (per-request shapes) intact.
def _expand_cpu_to_gpu(
t: Optional[torch.Tensor],
*,
dtype: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]:
if t is None:
return None
base = t[:num_reqs]
if repeat_counts_cpu is not None:
base = base.repeat_interleave(repeat_counts_cpu, dim=0)
return base.to(device=self.device,
dtype=dtype if dtype is not None else None,
non_blocking=True)
needs_prompt_token_ids = (not self.no_penalties or
(self.num_reqs > 0
and self.logits_processing_needs_token_ids))
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = self._make_prompt_token_ids_tensor(
repeat_counts_cpu)
else:
prompt_token_ids = None
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
allowed_token_ids_mask = self.allowed_token_ids_mask_cpu_tensor
# Expand per-request metadata to per-token shape when repeat_counts
# is provided (spec decode reject-sampling path).
top_p_cpu = None if self.no_top_p else self.top_p_cpu_tensor
top_k_cpu = None if self.no_top_k else self.top_k_cpu_tensor
repeat_list = repeat_counts_cpu.tolist()
row_offsets: list[int] = []
total_rows = 0
for repeat in repeat_list:
row_offsets.append(total_rows)
total_rows += int(repeat)
expanded_logitsprocs = self._expand_logitsprocs(
repeat_list, row_offsets, total_rows)
expanded_output_token_ids: list[list[int]] = []
expanded_bad_words_token_ids: dict[int, list[list[int]]] = {}
expanded_generators: dict[int, torch.Generator] = {}
row_idx = 0
for req_idx in range(num_reqs):
repeat = int(repeat_list[req_idx])
if repeat <= 0:
continue
output_tokens = self.req_output_token_ids[req_idx]
assert output_tokens is not None
bad_words = self.bad_words_token_ids.get(req_idx)
generator = self.generators.get(req_idx)
for _ in range(repeat):
expanded_output_token_ids.append(output_tokens)
if bad_words is not None:
expanded_bad_words_token_ids[row_idx] = bad_words
if generator is not None:
expanded_generators[row_idx] = generator
row_idx += 1
return SamplingMetadata(
temperature=_expand_cpu_to_gpu(
None if all_greedy else self.temperature_cpu_tensor),
all_greedy=all_greedy,
all_random=all_random,
top_p=_expand_cpu_to_gpu(top_p_cpu),
top_k=_expand_cpu_to_gpu(top_k_cpu, dtype=torch.int32),
generators=expanded_generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=(
None if self.no_penalties else _expand_cpu_to_gpu(
self.frequency_penalties_cpu_tensor)),
presence_penalties=(
None if self.no_penalties else _expand_cpu_to_gpu(
self.presence_penalties_cpu_tensor)),
repetition_penalties=(
None if self.no_penalties else _expand_cpu_to_gpu(
self.repetition_penalties_cpu_tensor)),
output_token_ids=expanded_output_token_ids,
no_penalties=self.no_penalties,
allowed_token_ids_mask=_expand_cpu_to_gpu(
allowed_token_ids_mask, dtype=torch.bool),
bad_words_token_ids=expanded_bad_words_token_ids,
logitsprocs=expanded_logitsprocs,
)
@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
......@@ -685,7 +804,9 @@ class InputBatch:
pooling_params=pooling_params,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
def _make_prompt_token_ids_tensor(
self, repeat_counts_cpu: Optional[torch.Tensor] = None
) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
......@@ -700,9 +821,116 @@ class InputBatch:
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
if repeat_counts_cpu is not None:
prompt_token_ids_cpu_tensor = prompt_token_ids_cpu_tensor \
.repeat_interleave(repeat_counts_cpu, dim=0)
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def _expand_logitsprocs(
self, repeat_list: list[int], row_offsets: list[int], total_rows: int
) -> LogitsProcessorManager:
"""Expand per-request logits processors to per-token shape for
repeat_counts-expanded batches."""
def _expand_min_p(proc: MinPLogitsProcessor) -> MinPLogitsProcessor:
expanded = MinPLogitsProcessor(
max_num_reqs=total_rows,
pin_memory=self.pin_memory,
device=self.device)
if total_rows == 0:
expanded.min_p_count = 0
expanded.min_p = expanded.min_p_device[:0]
expanded.min_p.unsqueeze_(1)
return expanded
base_min_p_cpu = torch.from_numpy(proc.min_p_cpu[:self.num_reqs])
repeats = torch.tensor(repeat_list, dtype=torch.int64)
expanded_min_p_cpu = base_min_p_cpu.repeat_interleave(repeats)
expanded.min_p_cpu_tensor[:total_rows].copy_(expanded_min_p_cpu)
expanded.min_p = expanded.min_p_device[:total_rows]
expanded.min_p.copy_(expanded.min_p_cpu_tensor[:total_rows],
non_blocking=True)
expanded.min_p.unsqueeze_(1)
expanded.min_p_count = int((expanded_min_p_cpu != 0).sum().item())
return expanded
def _expand_logit_bias(
proc: LogitBiasLogitsProcessor) -> LogitBiasLogitsProcessor:
expanded = LogitBiasLogitsProcessor(pin_memory=self.pin_memory,
device=self.device)
# Preserve biases dict for truthiness and reuse.
expanded.biases = proc.biases
if not proc.biases or total_rows == 0:
return expanded
req_indices: list[int] = []
tok_indices: list[int] = []
bias_vals: list[float] = []
for req_idx, lb in proc.biases.items():
repeat = repeat_list[req_idx]
if repeat <= 0:
continue
start = row_offsets[req_idx]
tok_ids = list(lb.keys())
biases = list(lb.values())
for row in range(start, start + repeat):
req_indices.extend([row] * len(tok_ids))
tok_indices.extend(tok_ids)
bias_vals.extend(biases)
if bias_vals:
expanded.bias_tensor = expanded._device_tensor(
bias_vals, torch.float32)
expanded.logits_slice = (
expanded._device_tensor(req_indices, torch.int32),
expanded._device_tensor(tok_indices, torch.int32),
)
return expanded
def _expand_min_tokens(
proc: MinTokensLogitsProcessor) -> MinTokensLogitsProcessor:
expanded = MinTokensLogitsProcessor(pin_memory=self.pin_memory,
device=self.device)
expanded.min_toks = proc.min_toks
if not proc.min_toks or total_rows == 0:
return expanded
req_indices: list[int] = []
tok_indices: list[int] = []
for req_idx, (_, _, stop_tok_ids) in proc.min_toks.items():
repeat = repeat_list[req_idx]
if repeat <= 0:
continue
start = row_offsets[req_idx]
stop_ids = list(stop_tok_ids)
for row in range(start, start + repeat):
req_indices.extend([row] * len(stop_ids))
tok_indices.extend(stop_ids)
if tok_indices:
expanded.logits_slice = (
expanded._device_tensor(req_indices, torch.int32),
expanded._device_tensor(tok_indices, torch.int32),
)
return expanded
expanded_argmax: list = []
for proc in self.logitsprocs.argmax_invariant:
if isinstance(proc, MinPLogitsProcessor):
expanded_argmax.append(_expand_min_p(proc))
else:
expanded_argmax.append(proc)
expanded_non_argmax: list = []
for proc in self.logitsprocs.non_argmax_invariant:
if isinstance(proc, LogitBiasLogitsProcessor):
expanded_non_argmax.append(_expand_logit_bias(proc))
elif isinstance(proc, MinTokensLogitsProcessor):
expanded_non_argmax.append(_expand_min_tokens(proc))
else:
expanded_non_argmax.append(proc)
return LogitsProcessorManager(
argmax_invariant=expanded_argmax,
non_argmax_invariant=expanded_non_argmax,
)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
......
......@@ -572,8 +572,21 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.input_batch.condense()
# Allow attention backend to reorder the batch, potentially
self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
# Refresh batch metadata with any pending updates. If we are in spec
# decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts: Optional[torch.Tensor] = None
if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens:
num_reqs = self.input_batch.num_reqs
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None:
num_draft_tokens[req_idx] = len(draft_token_ids)
repeat_counts = torch.from_numpy(num_draft_tokens).add_(1)
self.input_batch.refresh_metadata(repeat_counts)
def _get_cumsum_and_arange(
self,
......@@ -3360,8 +3373,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
sampler_output.sampled_token_ids = output_token_ids
else:
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment