Unverified Commit d9406076 authored by Yiliu Dong's avatar Yiliu Dong Committed by GitHub
Browse files

[Core] Support `min_tokens` with speculative decoding (#32642)


Signed-off-by: default avatarqianlihuang <yiliu.dong@qq.com>
Co-authored-by: default avatarqianlihuang <yiliu.dong@qq.com>
parent 99c7892c
...@@ -32,8 +32,7 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [ ...@@ -32,8 +32,7 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [
default_params = dict( default_params = dict(
temperature=0.0, # greedy temperature=0.0, # greedy
max_tokens=30, max_tokens=30,
# spec decoding currently doesn't support min_tokens min_tokens=28,
# min_tokens=28,
) )
......
...@@ -276,9 +276,12 @@ def test_rejects_custom_logitsprocs( ...@@ -276,9 +276,12 @@ def test_rejects_custom_logitsprocs(
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
llm = LLM(**llm_kwargs) llm = LLM(**llm_kwargs)
# Require that no logitsprocs have been loaded # Require that no custom logitsprocs have been loaded
# (built-in processors may exist: MinTokensLogitsProcessor,
# LogitBiasLogitsProcessor, MinPLogitsProcessor)
worker = llm.llm_engine.model_executor.driver_worker.worker worker = llm.llm_engine.model_executor.driver_worker.worker
assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0 for proc in worker.model_runner.input_batch.logitsprocs.all:
assert not isinstance(proc, DummyLogitsProcessor)
return return
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
......
...@@ -678,9 +678,9 @@ class SamplingParams( ...@@ -678,9 +678,9 @@ class SamplingParams(
return return
# Some sampling parameters are not yet compatible with spec decoding. # Some sampling parameters are not yet compatible with spec decoding.
if self.min_tokens > 1 or self.min_p > _SAMPLING_EPS or self.logit_bias: if self.min_p > _SAMPLING_EPS or self.logit_bias:
raise ValueError( raise ValueError(
"The min_tokens, min_p, and logit_bias sampling parameters " "The min_p and logit_bias sampling parameters "
"are not yet supported with speculative decoding." "are not yet supported with speculative decoding."
) )
......
...@@ -202,10 +202,11 @@ def build_logitsprocs( ...@@ -202,10 +202,11 @@ def build_logitsprocs(
if custom_logitsprocs: if custom_logitsprocs:
raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS) raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
logger.warning( logger.warning(
"min_p, logit_bias, and min_tokens parameters won't currently work " "min_p and logit_bias parameters won't work with speculative decoding."
"with speculative decoding enabled." )
return LogitsProcessors(
[MinTokensLogitsProcessor(vllm_config, device, is_pin_memory)]
) )
return LogitsProcessors()
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
return LogitsProcessors( return LogitsProcessors(
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
import numpy as np
import torch import torch
from vllm import SamplingParams from vllm import SamplingParams
...@@ -236,6 +237,59 @@ class MinTokensLogitsProcessor(LogitsProcessor): ...@@ -236,6 +237,59 @@ class MinTokensLogitsProcessor(LogitsProcessor):
logits.index_put_(self.logits_slice, self.neg_inf_tensor) logits.index_put_(self.logits_slice, self.neg_inf_tensor)
return logits return logits
def apply_with_spec_decode(
self,
logits: torch.Tensor,
num_draft_tokens: list[int],
) -> torch.Tensor:
"""Spec-decode version of apply().
Priority: ``min_tokens`` > ``stop_token_ids`` / EOS.
Example: ``num_draft_tokens = [2, 3, 1]``
→ ``logits`` shape ``[6, V]``, ``cumsum = [0, 2, 5, 6]``
→ request 0 owns rows 0‑1, request 1 rows 2‑4, request 2 row 5.
"""
if not self.min_toks:
return logits
num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])
entries = [
(req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
if stop_tok_ids
]
if not entries:
return logits
all_rows: list[np.ndarray] = [] # row indices to mask
all_toks: list[np.ndarray] = [] # stop-token ids at those rows
for req_idx, min_tok, current_len, stop_toks in entries:
remaining = min_tok - current_len
# How many leading draft positions still need stop-token masking.
n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))
if n_mask > 0:
offset = cumsum[req_idx]
row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
n_stop = len(stop_toks)
all_rows.append(np.repeat(row_indices, n_stop))
all_toks.append(np.tile(stop_toks, n_mask))
if all_rows:
rows_arr = np.concatenate(all_rows)
toks_arr = np.concatenate(all_toks)
# (row_indices, token_indices) for index_put_ to set -inf.
logits_slice = (
torch.from_numpy(rows_arr).to(self.device, non_blocking=True),
torch.from_numpy(toks_arr).to(self.device, non_blocking=True),
)
logits.index_put_(logits_slice, self.neg_inf_tensor)
return logits
def process_dict_updates( def process_dict_updates(
req_entries: dict[int, T], req_entries: dict[int, T],
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator from collections.abc import Iterable, Iterator
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
...@@ -148,7 +148,7 @@ class BatchUpdateBuilder: ...@@ -148,7 +148,7 @@ class BatchUpdateBuilder:
class LogitsProcessors: class LogitsProcessors:
"""Encapsulates initialized logitsproc objects.""" """Encapsulates initialized logitsproc objects."""
def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None: def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None:
self.argmax_invariant: list[LogitsProcessor] = [] self.argmax_invariant: list[LogitsProcessor] = []
self.non_argmax_invariant: list[LogitsProcessor] = [] self.non_argmax_invariant: list[LogitsProcessor] = []
if logitsprocs: if logitsprocs:
......
...@@ -10,6 +10,7 @@ import torch.nn as nn ...@@ -10,6 +10,7 @@ import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
from vllm.v1.sample.logits_processor.builtin import MinTokensLogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.penalties import apply_all_penalties
...@@ -292,6 +293,12 @@ class RejectionSampler(nn.Module): ...@@ -292,6 +293,12 @@ class RejectionSampler(nn.Module):
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
) )
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
if isinstance(processor, MinTokensLogitsProcessor):
logits = processor.apply_with_spec_decode(
logits, metadata.num_draft_tokens
)
return logits return logits
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment