Unverified Commit dd572c0a authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 Spec Decode workers (#21152)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 9ffe905a
...@@ -2536,8 +2536,6 @@ class DeviceConfig: ...@@ -2536,8 +2536,6 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp"] "mlp_speculator", "draft_model", "deepseek_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]
@config @config
...@@ -2560,13 +2558,6 @@ class SpeculativeConfig: ...@@ -2560,13 +2558,6 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered.""" `prompt_lookup_min` should be considered."""
acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler"
"""The method to use for accepting draft tokens:\n
- "rejection_sampler" maps to `RejectionSampler`.\n
- "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.
If using `typical_acceptance_sampler`, the related configuration
`posterior_threshold` and `posterior_alpha` should be considered."""
draft_tensor_parallel_size: Optional[int] = None draft_tensor_parallel_size: Optional[int] = None
"""The degree of the tensor parallelism for the draft model. Can only be 1 """The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size.""" or the same as the target model's tensor parallel size."""
...@@ -2593,9 +2584,6 @@ class SpeculativeConfig: ...@@ -2593,9 +2584,6 @@ class SpeculativeConfig:
will use the default version.""" will use the default version."""
# Advanced control # Advanced control
disable_mqa_scorer: bool = False
"""Disable the MQA scorer and fall back to batch expansion for scoring
proposals."""
disable_by_batch_size: Optional[int] = None disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number """Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided.""" of enqueued requests is larger than this value, if provided."""
...@@ -2608,16 +2596,6 @@ class SpeculativeConfig: ...@@ -2608,16 +2596,6 @@ class SpeculativeConfig:
"""Minimum size of ngram token window when using Ngram proposer, if """Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1.""" provided. Defaults to 1."""
# Typical acceptance sampler configuration
posterior_threshold: Optional[float] = None
"""A threshold value that sets a lower bound on the posterior probability
of a token in the target model for it to be accepted. This threshold is
used only when we use the `TypicalAcceptanceSampler` for token acceptance.
"""
posterior_alpha: Optional[float] = None
"""Scaling factor for entropy-based threshold, applied when using
`TypicalAcceptanceSampler`."""
speculative_token_tree: Optional[str] = None speculative_token_tree: Optional[str] = None
"""Specifies the tree structure for speculative token generation. """Specifies the tree structure for speculative token generation.
""" """
...@@ -2795,8 +2773,8 @@ class SpeculativeConfig: ...@@ -2795,8 +2773,8 @@ class SpeculativeConfig:
elif (self.draft_model_config.hf_config.model_type == elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"): "mlp_speculator"):
self.method = "mlp_speculator" self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type == elif (self.draft_model_config.hf_config.model_type
"deepseek_mtp"): in ("deepseek_mtp", "mimo_mtp")):
self.method = "deepseek_mtp" self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1: if self.num_speculative_tokens > 1:
logger.warning( logger.warning(
...@@ -2806,6 +2784,11 @@ class SpeculativeConfig: ...@@ -2806,6 +2784,11 @@ class SpeculativeConfig:
) )
else: else:
self.method = "draft_model" self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
# Replace hf_config for EAGLE draft_model # Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"): if self.method in ("eagle", "eagle3"):
...@@ -2864,12 +2847,6 @@ class SpeculativeConfig: ...@@ -2864,12 +2847,6 @@ class SpeculativeConfig:
self.target_parallel_config, self.target_parallel_config,
self.draft_tensor_parallel_size)) self.draft_tensor_parallel_size))
if self.acceptance_method == "typical_acceptance_sampler":
if self.posterior_threshold is None:
self.posterior_threshold = 0.09
if self.posterior_alpha is None:
self.posterior_alpha = 0.3
@staticmethod @staticmethod
def _maybe_override_draft_max_model_len( def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int], speculative_max_model_len: Optional[int],
...@@ -2975,30 +2952,6 @@ class SpeculativeConfig: ...@@ -2975,30 +2952,6 @@ class SpeculativeConfig:
if self.draft_model_config: if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config( self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config) self.draft_parallel_config)
# Validate and set draft token acceptance related settings.
if self.acceptance_method is None:
raise ValueError("acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler.")
if (self.acceptance_method != 'rejection_sampler'
and self.acceptance_method != 'typical_acceptance_sampler'):
raise ValueError(
"Expected acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f"is {self.acceptance_method}")
if self.acceptance_method == "typical_acceptance_sampler" and (
(self.posterior_threshold is not None
and self.posterior_threshold < 0) or
(self.posterior_alpha is not None and self.posterior_alpha < 0)):
raise ValueError(
"Expected the posterior_threshold and posterior_alpha of "
"typical_acceptance_sampler to be > 0. "
"Instead found posterior_threshold = "
f"{self.posterior_threshold} and posterior_alpha = "
f"{self.posterior_alpha}")
if (self.disable_by_batch_size is not None if (self.disable_by_batch_size is not None
and self.disable_by_batch_size < 2): and self.disable_by_batch_size < 2):
......
...@@ -1417,28 +1417,12 @@ class EngineArgs: ...@@ -1417,28 +1417,12 @@ class EngineArgs:
return False return False
# V1 supports N-gram, Medusa, and Eagle speculative decoding. # V1 supports N-gram, Medusa, and Eagle speculative decoding.
is_ngram_enabled = False if (self.speculative_config is not None
is_eagle_enabled = False and self.speculative_config.get("method") == "draft_model"):
is_medusa_enabled = False raise NotImplementedError(
if self.speculative_config is not None: "Speculative decoding with draft model is not supported yet. "
# This is supported but experimental (handled below). "Please consider using other speculative decoding methods "
speculative_method = self.speculative_config.get("method") "such as ngram, medusa, eagle, or deepseek_mtp.")
if speculative_method:
if speculative_method in ("ngram", "[ngram]"):
is_ngram_enabled = True
elif speculative_method == "medusa":
is_medusa_enabled = True
elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
if speculative_model in ("ngram", "[ngram]"):
is_ngram_enabled = True
if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
# Other speculative decoding methods are not supported yet.
_raise_or_fallback(feature_name="Speculative Decoding",
recommend_to_remove=False)
return False
# No XFormers so far. # No XFormers so far.
V1_BACKENDS = [ V1_BACKENDS = [
......
...@@ -1780,13 +1780,6 @@ class LLMEngine: ...@@ -1780,13 +1780,6 @@ class LLMEngine:
num_generation_tokens_from_prefill_groups) num_generation_tokens_from_prefill_groups)
num_tokens_iter = (num_generation_tokens_iter + num_tokens_iter = (num_generation_tokens_iter +
num_prompt_tokens_iter) num_prompt_tokens_iter)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and isinstance(model_output[0], SamplerOutput) and (
model_output[0].spec_decode_worker_metrics is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None
return Stats( return Stats(
now=now, now=now,
...@@ -1808,7 +1801,6 @@ class LLMEngine: ...@@ -1808,7 +1801,6 @@ class LLMEngine:
num_tokens_iter=num_tokens_iter, num_tokens_iter=num_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter, time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter, time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
num_preemption_iter=num_preemption_iter, num_preemption_iter=num_preemption_iter,
# Request stats # Request stats
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Type, Union, cast from typing import Dict, List, Optional, Type, Union, cast
...@@ -19,9 +18,6 @@ if ray is not None: ...@@ -19,9 +18,6 @@ if ray is not None:
else: else:
ray_metrics = None ray_metrics = None
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__) logger = init_logger(__name__)
prometheus_client.disable_created_metrics() prometheus_client.disable_created_metrics()
...@@ -199,30 +195,6 @@ class Metrics: ...@@ -199,30 +195,6 @@ class Metrics:
documentation="Count of successfully processed requests.", documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason]) labelnames=labelnames + [Metrics.labelname_finish_reason])
# Speculative decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_spec_decode_efficiency = self._gauge_cls(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames,
multiprocess_mode="sum")
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))
# --8<-- [end:metrics-definitions] # --8<-- [end:metrics-definitions]
...@@ -391,9 +363,6 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -391,9 +363,6 @@ class LoggingStatLogger(StatLoggerBase):
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log, if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval): self.local_interval):
...@@ -435,10 +404,6 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -435,10 +404,6 @@ class LoggingStatLogger(StatLoggerBase):
stats.gpu_prefix_cache_hit_rate * 100, stats.gpu_prefix_cache_hit_rate * 100,
stats.cpu_prefix_cache_hit_rate * 100, stats.cpu_prefix_cache_hit_rate * 100,
) )
if self.spec_decode_metrics is not None:
log_fn(
self._format_spec_decode_metrics_str(
self.spec_decode_metrics))
self._reset(stats, prompt_throughput, generation_throughput) self._reset(stats, prompt_throughput, generation_throughput)
...@@ -447,21 +412,9 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -447,21 +412,9 @@ class LoggingStatLogger(StatLoggerBase):
self.num_prompt_tokens = [] self.num_prompt_tokens = []
self.num_generation_tokens = [] self.num_generation_tokens = []
self.last_local_log = stats.now self.last_local_log = stats.now
self.spec_decode_metrics = None
self.last_prompt_throughput = prompt_throughput self.last_prompt_throughput = prompt_throughput
self.last_generation_throughput = generation_throughput self.last_generation_throughput = generation_throughput
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens: {metrics.emitted_tokens}.")
def info(self, type: str, obj: SupportsMetricsInfo) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -579,33 +532,14 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -579,33 +532,14 @@ class PrometheusStatLogger(StatLoggerBase):
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log, if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval): self.local_interval):
if self.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
self.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
self.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
self.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
self.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
self.spec_decode_metrics.emitted_tokens)
# Reset tracked stats for next interval. # Reset tracked stats for next interval.
self.num_prompt_tokens = [] self.num_prompt_tokens = []
self.num_generation_tokens = [] self.num_generation_tokens = []
self.last_local_log = stats.now self.last_local_log = stats.now
self.spec_decode_metrics = None
def info(self, type: str, obj: SupportsMetricsInfo) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
# Info type metrics are syntactic sugar for a gauge permanently set to 1 # Info type metrics are syntactic sugar for a gauge permanently set to 1
......
...@@ -16,10 +16,9 @@ do this in Python code and lazily import prometheus_client. ...@@ -16,10 +16,9 @@ do this in Python code and lazily import prometheus_client.
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import List
from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass @dataclass
...@@ -65,8 +64,6 @@ class Stats: ...@@ -65,8 +64,6 @@ class Stats:
running_lora_adapters: List[str] running_lora_adapters: List[str]
max_lora: str max_lora: str
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class StatLoggerBase(ABC): class StatLoggerBase(ABC):
"""Base class for StatLogger.""" """Base class for StatLogger."""
...@@ -77,7 +74,6 @@ class StatLoggerBase(ABC): ...@@ -77,7 +74,6 @@ class StatLoggerBase(ABC):
self.num_generation_tokens: List[int] = [] self.num_generation_tokens: List[int] = []
self.last_local_log = time.time() self.last_local_log = time.time()
self.local_interval = local_interval self.local_interval = local_interval
self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None
@abstractmethod @abstractmethod
def log(self, stats: Stats) -> None: def log(self, stats: Stats) -> None:
...@@ -86,9 +82,3 @@ class StatLoggerBase(ABC): ...@@ -86,9 +82,3 @@ class StatLoggerBase(ABC):
@abstractmethod @abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError raise NotImplementedError
def maybe_update_spec_decode_metrics(self, stats: Stats):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if stats.spec_decode_metrics is not None:
self.spec_decode_metrics = stats.spec_decode_metrics
...@@ -104,11 +104,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -104,11 +104,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
seqs = sequence_group.get_seqs( seqs = sequence_group.get_seqs(
status=SequenceStatus.FINISHED_ABORTED) status=SequenceStatus.FINISHED_ABORTED)
for output in outputs:
if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
sequence_group.metrics.spec_token_acceptance_counts[
output.step_index] += 1
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
assert len(seqs) == 1, ( assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.") "Beam search not supported in multi-step decoding.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cached_property
from importlib.util import find_spec
from typing import Optional
import torch
import torch.jit
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeStochasticBaseSampler)
from vllm.platforms import current_platform
logger = init_logger(__name__)
if find_spec("flashinfer"):
"""
Consider utilizing the FlashInfer rejection sampling kernel initially,
as it employs a dedicated kernel rather than relying on
Torch tensor operations. This design choice helps to fuse operations,
reduce memory I/O, and consequently enhances performance.
"""
from flashinfer.sampling import chain_speculative_sampling
else:
chain_speculative_sampling = None
class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
"""
def __init__(self,
strict_mode: bool = False,
use_flashinfer: Optional[bool] = None):
"""Create a rejection sampler.
Args:
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
use_flashinfer: We will use this parameter to determine whether
to use the FlashInfer rejection sampling kernel or not. If it's
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
"""
super().__init__(strict_mode=strict_mode)
if use_flashinfer is None:
self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
chain_speculative_sampling is not None)
else:
self.use_flashinfer = use_flashinfer
if self.use_flashinfer:
logger.info("Use flashinfer for rejection sampling.")
else:
logger.info("Use pytorch for rejection sampling.")
def forward(
self,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
seeded_seqs: Optional[dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
according to the draft and target models.
In the worst case where all draft tokens are rejected, it is guaranteed
one correct token will be emitted.
In the case where all draft tokens are accepted, a bonus token will be
accepted as its cheap to have the target model score this speculative
sequence.
Args:
target_with_bonus_probs: The probability distribution
over token ids given context according to the target model.
shape = [batch_size, num_speculative_tokens + 1, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_probs: The probability distribution over token ids given
context according to the draft model.
shape = [batch_size, num_speculative_tokens, vocab_size]
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
was rejected.
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
"""
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_input(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
batch_size, k, _ = draft_probs.shape
# batch_size = 0 when all requests in the batch are
# non_spec requests. In this case, output_token_ids is
# just an empty tensor.
if batch_size == 0:
return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)
# If use Flashinfer chain_speculative_sampling kernel
# for rejection sampling
if self.use_flashinfer and chain_speculative_sampling is not None:
batch_size, k, _ = draft_probs.shape
(output_token_ids, accepted_token_num,
emitted_token_num) = chain_speculative_sampling(
draft_probs,
draft_token_ids,
target_with_bonus_probs,
)
# num_emitted_tokens returned by flashinfer
# does not include the bonus token
# Flashinfer stops at the first token that violates
# the condition p >= q and does not include recovery/bonus token.
# Therefore, we need to add batch_size here.
self.num_accepted_tokens += accepted_token_num.sum()
self.num_emitted_tokens += emitted_token_num.sum() + batch_size
self.num_draft_tokens += batch_size * k
else:
accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling(
target_with_bonus_probs[:, :-1],
draft_probs,
draft_token_ids,
seeded_seqs,
))
output_token_ids = self._create_output(
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
return output_token_ids
def _batch_modified_rejection_sampling(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[dict[int, torch.Generator]],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
Returns:
A tuple of two tensors:
0: A bool tensor of which tokens in each sequence is accepted.
shape = [batch_size, k]
1: Token ids sampled from a recovered distribution, to be used
when a token is rejected.
shape = [batch_size, k]
"""
batch_size, k, vocab_size = draft_probs.shape
# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids, seeded_seqs)
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)
return accepted, recovered_token_ids
def _create_uniform_samples(self,
seeded_seqs: Optional[dict[int,
torch.Generator]],
batch_size: int, k: int,
device: torch.device) -> torch.Tensor:
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.
This method creates a tensor of shape `(batch_size, k + 1)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.
Args:
seeded_seqs : Optional[dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k + 1)` containing uniform
random values in the range [0, 1).
"""
if not seeded_seqs:
return torch.rand(batch_size, k + 1, device=device)
uniform_rand = torch.empty(batch_size, k + 1, device=device)
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(1,
k + 1,
dtype=self.probs_dtype,
device=device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k + 1,
dtype=self.probs_dtype,
device=device)
return uniform_rand
def _get_accepted(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
rejected.
Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
$\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
same conditional probability according to the draft model, the token
is accepted with probability:
$$
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
$$
This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used.
Returns a bool tensor of shape [batch_size, k] specifying which tokens
are accepted.
"""
batch_size, k, _ = draft_probs.shape
batch_indices = torch.arange(batch_size,
device=target_probs.device)[:, None]
probs_indices = torch.arange(k, device=target_probs.device)
# shape [batch_size, k]
selected_draft_probs = draft_probs[batch_indices, probs_indices,
draft_token_ids]
# shape [batch_size, k]
selected_target_probs = target_probs[batch_indices, probs_indices,
draft_token_ids]
uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
k - 1, target_probs.device)
capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,
torch.full((1, ), 1, device=target_probs.device))
accepted = uniform_rand < capped_ratio
return accepted
def _get_recovered_probs(
self,
target_probs: torch.Tensor, # [k, vocab_size]
draft_probs: torch.Tensor, # [k, vocab_size]
) -> torch.Tensor:
r"""Create a probability distribution for each proposed token which can
be sampled if the proposed token is rejected.
When this routine is applied sequentially, the true distribution of the
target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed
as follows. Given $q(x|x_1, \dots, x_n)$, the probability of
$x$ given context $x_1, \dots, x_n$ according to the target
model and $p(x|x_1, \dots, x_n)$, the same conditional probability
according to the draft model:
$$
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
$$
where $(f(x))_+$ is defined as:
$$
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
$$
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size].
Note:
This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
"""
_, k, _ = draft_probs.shape
# shape [batch_size, k, vocab_size]
difference = target_probs - draft_probs
# TODO(cade): Can we use logprobs instead of probs, and avoid the
# division-by-zero errors without introducing distribution drift?
# shape [batch_size, k, vocab_size]
f = torch.clamp(difference, min=self._smallest_positive_value)
# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
return recovered_probs
@cached_property
def _smallest_positive_value(self) -> float:
"""Return the smallest positive value representable by the probs dtype.
This value is used when constructing a distribution from which to sample
recovered tokens in the first rejection case.
See _get_recovered_probs for more details
Note that this isn't actually the smallest positive value representable
by float32, but the smallest positive normal value.
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
"""
return torch.finfo(self.probs_dtype).tiny
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
seeded_seqs: dict[int, torch.Generator],
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs)
if not seeded_seqs:
q.exponential_(1.0)
else:
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
# Note: generator might be None for non seeded
q[start:end].exponential_(1.0, generator=generator)
start = end
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
...@@ -21,7 +21,6 @@ from vllm.sampling_params import SamplingType ...@@ -21,7 +21,6 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob, CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput) PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
# yapf: disable # yapf: disable
...@@ -119,9 +118,6 @@ class SamplerOutput( ...@@ -119,9 +118,6 @@ class SamplerOutput(
# specified in lieu of prompt token ids or text. # specified in lieu of prompt token ids or text.
sampled_token_embeds: Optional[torch.Tensor] = None sampled_token_embeds: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
# Optional last hidden states from the model. # Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None hidden_states: Optional[torch.Tensor] = None
...@@ -159,11 +155,9 @@ class SamplerOutput( ...@@ -159,11 +155,9 @@ class SamplerOutput(
else self.sampled_token_probs.shape) else self.sampled_token_probs.shape)
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
self.sampled_token_ids.shape) self.sampled_token_ids.shape)
return ( return (f"SamplerOutput(outputs={self.outputs}, "
f"SamplerOutput(outputs={self.outputs}, " f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_probs={sampled_token_probs_repr}, " f"sampled_token_ids={sampled_token_ids_repr})")
f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
class Sampler(nn.Module): class Sampler(nn.Module):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from typing import Optional, Union
import torch
import torch.jit
import torch.nn as nn
from vllm.platforms import current_platform
class SpecDecodeBaseSampler(nn.Module):
"""Base class for samplers used for Speculative Decoding verification
step.
"""
def __init__(self, strict_mode: bool = False):
"""Base class constructor.
Args:
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self._num_bonus_tokens = 1
self.num_accepted_tokens: Optional[torch.Tensor] = None
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, device: Union[int, str]) -> None:
assert self.num_accepted_tokens is None
if isinstance(device, int):
device = f"{current_platform.device_type}:{device}"
elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}")
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
def init_tensors(self,
device: Union[int, str],
device_type: Union[torch.device, str] = 'cuda') -> None:
assert self.num_accepted_tokens is None
if isinstance(device_type, torch.device):
device_type = device_type.type
if isinstance(device, int):
device = f"{device_type}:{device}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
@property
def probs_dtype(self):
return torch.float32
@property
def token_id_dtype(self):
return torch.int64
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
substitute_token_ids: torch.Tensor, # [batch_size, k]
draft_token_ids: torch.Tensor, # [batch_size, k]
bonus_token_ids: torch.Tensor, # [batch_size]
) -> torch.Tensor:
"""Format output. Returns a matrix of token ids. When
a token is rejected via sampling, all subsequent token ids are
set to -1 for the sequence.
Args:
accepted: A boolean tensor indicating if the corresponding
draft token in draft_token_ids should be accepted or not.
substitute_token_ids: A tensor of token_ids that can be used
as substitutes for the draft token ids if the proposed token
is rejected.
draft_token_ids: A tensor of token ids speculated by the
draft model.
bonus_token_ids: Token ids to use as the bonus token if
all the draft tokens are accepted.
Returns:
A tensor containing the accepted token ids. The shape of the
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
after_false_mask = indices == limits.unsqueeze(1)
# Create an extended output tensor
output_with_bonus_tokens = -torch.ones(
(batch_size, k + self._num_bonus_tokens),
dtype=self.token_id_dtype,
device=accepted.device)
output = output_with_bonus_tokens[:, :k]
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-torch.ones_like(draft_token_ids))
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
# with causal acceptance.
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
substitute_token_ids.mul(after_false_mask))
self.num_accepted_tokens += accepted.sum()
self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
self.num_draft_tokens += batch_size * k
return output_with_bonus_tokens
def _raise_if_incorrect_input(
self,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
self._raise_if_incorrect_shape(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_incorrect_dtype(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_inconsistent_device(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
draft_token_ids, bonus_token_ids)
def _raise_if_incorrect_shape(
self,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_with_bonus_probs.shape
# Does not count the extra token
num_target_probs -= 1
# validate the shape of draft token ids.
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
assert draft_token_ids_batch_size == target_batch_size
assert num_draft_token_ids == num_target_probs
# validate the shape of bonus token ids
bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
assert bonus_batch_size == target_batch_size
assert num_bonus_tokens == self._num_bonus_tokens
# validate the shape of draft probs if it is set
if draft_probs is not None:
(draft_batch_size, num_draft_probs,
draft_vocab_size) = draft_probs.shape
assert draft_batch_size == target_batch_size
assert num_draft_probs == num_target_probs
assert (draft_vocab_size == target_vocab_size
), f"{draft_vocab_size=} {target_vocab_size=}"
def _raise_if_incorrect_dtype(
self,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
assert target_with_bonus_probs.dtype == self.probs_dtype
assert draft_token_ids.dtype == self.token_id_dtype
assert bonus_token_ids.dtype == self.token_id_dtype
if draft_probs is not None:
assert draft_probs.dtype == self.probs_dtype
def _raise_if_inconsistent_device(
self,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
devices = [
t.device for t in [
target_with_bonus_probs, bonus_token_ids, draft_probs,
draft_token_ids
] if t is not None
]
assert all([devices[0] == device for device in devices])
def _raise_if_out_of_bounds_vocab(
self,
vocab_size: int,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
) -> None:
assert torch.all(bonus_token_ids < vocab_size)
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
"""Base class for samplers used for Speculative Decoding verification
step which are deterministic.
"""
@abstractmethod
def forward(
self,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
"""Base class for samplers used for Speculative Decoding verification
step which are stochastic
"""
@abstractmethod
def forward(
self,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
seeded_seqs: Optional[dict[int, torch.Generator]] = None,
) -> torch.Tensor:
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.jit
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeDeterministicBaseSampler)
class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
"""Apply typical acceptance sampling as described in section 3.3.1 in
"MEDUSA: Simple LLM Inference Acceleration Framework with
Multiple Decoding Heads"
https://arxiv.org/pdf/2401.10774
"""
def __init__(
self,
posterior_threshold: float,
posterior_alpha: float,
strict_mode: bool = False,
):
"""Create a Typical Acceptance Sampler.
Args:
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
posterior_threshold : A threshold value that sets a lower bound
on the posterior probability of a token in target model for it
to be accepted.
posterior_alpha : A scaling factor for the entropy-based
threshold in typical acceptance sampling.
"""
self._posterior_threshold = posterior_threshold
self._posterior_alpha = posterior_alpha
super().__init__(strict_mode=strict_mode)
def forward(
self,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
"""Sample token ids using typical acceptance sampling. This accepts
or rejects tokens proposed by the draft model using the probability
of each token according to the draft and target models.
In the worst case where all draft tokens are rejected, it is guaranteed
one token will be emitted.
In the case where all draft tokens are accepted, the bonus token will be
accepted.
Args:
target_probs: The probability distribution over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_probs: This parameter is unused by the acceptance sampler.
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
was rejected.
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
"""
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_input(target_with_bonus_probs,
draft_token_ids, bonus_token_ids)
target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids)
recovered_token_ids = self._get_recovered_token_ids(target_probs)
output_token_ids = self._create_output(accepted, recovered_token_ids,
draft_token_ids,
bonus_token_ids)
return output_token_ids
def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
r"""
Evaluates and returns a mask of accepted tokens based on the
posterior probabilities.
Args:
target_probs (torch.Tensor): A tensor of shape
(batch_size, k, vocab_size) representing the probabilities of
each token in the vocabulary for each position in the proposed
sequence. This is the distribution generated by the target
model.
draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k)
representing the proposed token ids.
A draft token_id x_{n+k} is accepted if it satisfies the
following condition
$$
p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
\min \left( \epsilon, \delta * \exp \left(
-H(p_{\text{original}}(
\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
$$
where $p_{\text{original}}$ corresponds to target_probs
and $\epsilon$ and $\delta$ correspond to hyperparameters
specified using self._posterior_threshold and self._posterior_alpha
This method computes the posterior probabilities for the given
draft token ids based on the provided target probabilities. It
calculates the entropy of the posterior distribution and determines
a dynamic threshold for each token position using the provided
posterior_threshold and posterior_alpha values. The method then
returns a boolean mask indicating which tokens can be accepted.
Returns:
torch.Tensor: A boolean tensor of shape (batch_size, k) where each
element indicates whether the corresponding draft token has
been accepted or rejected. True indicates acceptance and false
indicates rejection.
"""
device = target_probs.device
candidates_prob = torch.gather(
target_probs, dim=-1,
index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
# A small constant added to prevent computing the logarithm of zero,
# which can lead to undefined values.
epsilon = 1e-5
posterior_entropy = -torch.sum(
target_probs * torch.log(target_probs + epsilon), dim=-1)
threshold = torch.minimum(
torch.ones_like(posterior_entropy, device=device) *
self._posterior_threshold,
torch.exp(-posterior_entropy) * self._posterior_alpha,
)
accepted_mask = candidates_prob > threshold
return accepted_mask
def _get_recovered_token_ids(self, target_probs):
"""
The recovered token ids will fill the first unmatched token
by the target token.
Args:
target_probs (torch.Tensor): A tensor of shape
(batch_size, k, vocab_size) containing the target probability
distribution.
Returns:
torch.Tensor: A tensor of shape (batch_size, k) with the recovered
token ids which are selected from target probs.
"""
max_indices = torch.argmax(target_probs, dim=-1)
return max_indices
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
logger = init_logger(__name__)
class DummyInputLayerNorm(nn.Module):
def __init__(self, weight=None, bias=None):
super().__init__()
self.weight = nn.Parameter(weight) if weight is not None else None
self.bias = nn.Parameter(bias) if bias is not None else None
def forward(self, x):
return x
class DummyOutputNorm(nn.Module):
def forward(self, x, residual):
if residual is None:
return x
else:
return x + residual, None
class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
Following this approach, our implementation also disables
the input_layernorm for the first decoder layer.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute.
4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
module with regards to the use of additional RMS norms. The original
EAGLE architecture 1) skips the pre-attention norm in its first
transformer block, and 2) skips the final output norm, both of which we
found to be suboptimal. We also add the support for separate norms
applying to both the token embedding and hidden states before projection
as in DeepSeek MTP, which we found to improve performance as well.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.dtype = vllm_config.model_config.dtype
self.config = config
architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=getattr(self.config, "eagle_fc_bias", False))
# Modify layer normalization and residual connections as suggested
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
# While weights and biases are generally not needed,
# they are retained here to support certain unit tests
# (e.g., spec_decode/e2e/test_eagle_correctness.py).
if not hasattr(self.config.model,
"skip_prenorm") or self.config.model.skip_prenorm:
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
weight=self.model.model.layers[0].input_layernorm.weight)
if not hasattr(
self.config.model,
"skip_output_norm") or self.config.model.skip_output_norm:
self.model.model.norm = DummyOutputNorm()
self.add_para_norm = False
if hasattr(self.config.model,
"add_para_norm") and self.config.model.add_para_norm:
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.add_para_norm = True
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.truncated_vocab_size,
logit_scale)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self.token_map = None
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
# Handle both empty previous_hidden_states
# and mismatched batch size
batch_size = inputs_embeds.size(0)
if previous_hidden_states.size(0) == 0 or \
previous_hidden_states.size(0) != batch_size:
hidden_dim = self.config.model.hidden_size
device = inputs_embeds.device
# Create zero tensor with matching batch size
previous_hidden_states = \
torch.zeros(batch_size, hidden_dim, device=device)
if self.add_para_norm:
inputs_embeds = torch.cat([
self.enorm(inputs_embeds),
self.hnorm(previous_hidden_states)
],
dim=-1)
else:
inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
dim=-1)
inputs_embeds = self.fc(inputs_embeds)
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
hidden_states = self.model.model(
input_ids=None,
inputs_embeds=inputs_embeds,
positions=positions,
intermediate_tensors=intermediate_tensors,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if self.token_map is not None:
_logits = logits
logits = -torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype)
logits[..., self.token_map] = _logits
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
# due to missing lm_head weights and its config being that of a
# Llama model. Here's a compatible version with the same weights:
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
# Also, here's an example script for converting trained EAGLE
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
model_weights = {}
for name, loaded_weight in weights:
if name == "token_map":
if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name.startswith("fc.weight"):
weight_loader = getattr(self.fc.weight, "weight_loader",
default_weight_loader)
weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("fc.bias"):
if self.fc.bias is not None:
weight_loader = getattr(self.fc.bias, "weight_loader",
default_weight_loader)
weight_loader(self.fc.bias, loaded_weight)
else:
logger.warning_once("Found bias in the loaded weights but "
"the model config doesn't have bias.")
elif name.startswith("enorm.weight"):
weight_loader = getattr(self.enorm.weight, "weight_loader",
default_weight_loader)
weight_loader(self.enorm.weight, loaded_weight)
elif name.startswith("hnorm.weight"):
weight_loader = getattr(self.hnorm.weight, "weight_loader",
default_weight_loader)
weight_loader(self.hnorm.weight, loaded_weight)
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
elif name.startswith("lm_head.") or name.startswith("model."):
model_weights[name] = loaded_weight
else:
model_weights[f"model.{name}"] = loaded_weight
if "lm_head.weight" in model_weights:
lm_head_weight = model_weights.pop("lm_head.weight")
if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:
lm_head_weight = lm_head_weight[self.token_map]
else:
# NOTE(Shangming): initialize the placeholder for lm_head weight.
lm_head_weight = torch.zeros(
self.lm_head.org_vocab_size,
self.lm_head.embedding_dim,
dtype=self.dtype,
)
weight_loader = getattr(self.lm_head.weight, "weight_loader",
default_weight_loader)
weight_loader(self.lm_head.weight, lm_head_weight)
self.model.load_weights(model_weights.items())
...@@ -239,14 +239,15 @@ _MULTIMODAL_MODELS = { ...@@ -239,14 +239,15 @@ _MULTIMODAL_MODELS = {
_SPECULATIVE_DECODING_MODELS = { _SPECULATIVE_DECODING_MODELS = {
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"), "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
"EAGLEModel": ("eagle", "EAGLE"),
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"MedusaModel": ("medusa", "Medusa"), "MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), # Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
} }
_TRANSFORMERS_MODELS = { _TRANSFORMERS_MODELS = {
......
...@@ -132,14 +132,10 @@ class CudaPlatformBase(Platform): ...@@ -132,14 +132,10 @@ class CudaPlatformBase(Platform):
parallel_config.worker_cls = \ parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker" "vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config: elif vllm_config.speculative_config:
if envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
parallel_config.worker_cls = \ raise NotImplementedError(
"vllm.v1.worker.gpu_worker.Worker" "Speculative decoding is not supported on vLLM V0.")
else: parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
else: else:
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
parallel_config.worker_cls = \ parallel_config.worker_cls = \
......
...@@ -326,15 +326,10 @@ class RocmPlatform(Platform): ...@@ -326,15 +326,10 @@ class RocmPlatform(Platform):
parallel_config.worker_cls = \ parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker" "vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config: elif vllm_config.speculative_config:
if envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise NotImplementedError( raise NotImplementedError(
"Speculative decoding is not yet supported on vLLM V1." "Speculative decoding is not supported on vLLM V0.")
) parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
else: else:
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
parallel_config.worker_cls = \ parallel_config.worker_cls = \
......
...@@ -112,13 +112,6 @@ class RequestMetrics: ...@@ -112,13 +112,6 @@ class RequestMetrics:
model_execute_time: The time spent in the model execute function. This model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time. workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at
each position; the first token is from
the target model and is always accepted;
e.g., when it's [10, 8, 4, 2] for a req,
it means there were 10 forward passes in
total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step.
""" """
arrival_time: float arrival_time: float
last_token_time: float last_token_time: float
...@@ -129,7 +122,6 @@ class RequestMetrics: ...@@ -129,7 +122,6 @@ class RequestMetrics:
scheduler_time: Optional[float] = None scheduler_time: Optional[float] = None
model_forward_time: Optional[float] = None model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None model_execute_time: Optional[float] = None
spec_token_acceptance_counts: Optional[list[int]] = None
class SequenceDataDelta( class SequenceDataDelta(
...@@ -748,9 +740,7 @@ class SequenceGroup: ...@@ -748,9 +740,7 @@ class SequenceGroup:
last_token_time=arrival_time, last_token_time=arrival_time,
first_scheduled_time=None, first_scheduled_time=None,
first_token_time=None, first_token_time=None,
time_in_queue=None, time_in_queue=None)
spec_token_acceptance_counts=[0] *
draft_size)
self.last_token_latency = 0.0 self.last_token_latency = 0.0
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
...@@ -1390,8 +1380,6 @@ class ExecuteModelRequest( ...@@ -1390,8 +1380,6 @@ class ExecuteModelRequest(
previous_hidden_states: Optional[HiddenStates] = None previous_hidden_states: Optional[HiddenStates] = None
# The number of forward steps to run. # The number of forward steps to run.
num_steps: int = 1 num_steps: int = 1
# The step index for spec model input.
spec_step_idx: Optional[int] = None
# Finished request ids since last step. # Finished request ids since last step.
finished_requests_ids: list[str] = msgspec.field(default_factory=list) finished_requests_ids: list[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding. # The last sampled token ids for multi step decoding.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from array import array
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple
import torch
from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
SeqId = int
TargetSeqId = int
TokenId = int
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Implements a speculative scorer that uses batch expansion to get
probabilities of speculative tokens according to the scoring model.
Batch expansion converts a list of sequences and multiple query positions
to a new batch of sequences, each with a single query position. This allows
for MQA-like scoring in speculative decoding without requiring an MQA
kernel.
It is strictly less efficient than MQA scoring.
It only supports scoring the top1 proposal tokens of the proposer, instead
of topk/tree.
"""
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.
This converts each input sequence to a set of k+1 target sequences. The
target sequences have the unique continuations to be scored and a
unique sequence ID that is different from all input sequence ids.
If a speculative sequence length would exceed the max model length, then
no speculation is produced for that sequence.
Args:
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
which sequences were ignored during scoring.
"""
# TODO(cade) perform this on GPU to remove blocking call.
proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
# Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if VLLM_INVALID_TOKEN_ID not in proposals
]
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled
return self._contract_batch_all_spec(
target_sampler_output=target_sampler_output,
proposals=proposals,
)
else:
# Batch has a mix of spec decode enabled and disabled seq groups
return self._contract_batch(
execute_model_req.seq_group_metadata_list,
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
def _expand_batch(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids_list: List[List[TokenId]],
proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
"""Given the input sequences and potentially multiple corresponding
proposal tokens, create a new batch where each sequence has a single
query token.
"""
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
(spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
split_batch_by_proposal_len(
seq_group_metadata_list, proposal_lens_list)
spec_expanded_seqs = self._create_scoring_model_input(
seq_group_metadata_list=spec_seqs,
proposal_token_ids=proposal_token_ids_list,
# NOTE: We determine the seq ids in the expanded batch using the
# full seq_group_metadata_list, instead of only spec_seqs.
target_seq_ids_iter=self._create_target_seq_id_iterator(
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
)
num_scoring_tokens = len(spec_expanded_seqs)
# Batch speculative and non-speculative (e.g. chunked prefill) requests
# but make sure order is prefill|decode due to backend requirement.
target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
def _contract_non_speculative(
self, scores: SpeculativeScores,
seq_group_metadata_list: List[SequenceGroupMetadata],
non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
has_prompt_log: bool) -> SpeculativeScores:
"""
Augment input `scores` with non-speculative requests outputs.
This includes decode requests with speculation turned off, as well
as prefill requests when `enable_chunked_prefill` is set.
For the latter, prefills are further separated into terminal and
non-terminal chunks (from which no token is sampled).
"""
if not non_spec_indices:
return scores
if has_prompt_log:
# When prompt_logprobs is enabled, prefills yield output token
# (and respective prob) in the last entry (prompt|out):
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
# With chunked prefill, non-terminal chunks have -1 on each
# position: they're still picked, but they're discarded later.
seq_meta = seq_group_metadata_list
nospec_sizes = torch.tensor([
seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
for i in non_spec_indices
])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
else:
# In this case only sampled tokens are returned, select all.
nospec_sampled_token_idxs = list(
range(len(non_spec_outputs.token_ids)))
scores.token_ids[non_spec_indices, :1] = \
non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
scores.probs[non_spec_indices, :1, :] = \
non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
scores.logprobs[non_spec_indices, :1, :] = \
non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
if scores.hidden_states is not None:
assert non_spec_outputs.hidden_states is not None
scores.hidden_states[non_spec_indices, :1, :] = \
non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
return scores
def _contract_batch(
self,
contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> SpeculativeScores:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
contracted_bs = len(contracted_seq_group_metadata_list)
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs,
non_spec_target_hidden_states) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
expanded_batch_size, k = proposals.proposal_token_ids.shape
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences, prefill chunks with no out tokens included
non_spec_expanded_bs = len(non_spec_indices)
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
target_probs = target_probs.reshape(*target_token_ids.shape,
self._vocab_size)
target_logprobs = target_logprobs.reshape(target_probs.shape)
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))
if target_sampler_output.hidden_states is not None:
all_hidden_states = target_hidden_states.new_zeros(
size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
else:
all_hidden_states = None
has_prompt_log = any((sg.sampling_params.prompt_logprobs
and sg.sampling_params.prompt_logprobs > 0)
for sg in contracted_seq_group_metadata_list)
# When prompt logprobs is enabled, lens of returned tensors go from
# n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
# We adjust stride accordingly to get the generated tokens and
# their probs, but pass on prompt_logprobs as is.
prompt_logprobs = None
if (not self._scorer_worker.model_runner.disable_logprobs\
and has_prompt_log):
prompt_logprobs = [
o.prompt_logprobs for o in target_sampler_output.outputs
]
elif not has_prompt_log:
# When prompt logprobs are not to be returned,
# we can ignore non-terminal chunks (no out token).
non_spec_indices = [
idx for idx in non_spec_indices
if contracted_seq_group_metadata_list[idx].do_sample
]
# "Contract" speculative.
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states
spec_scores = SpeculativeScores(probs=all_probs,
token_ids=all_tokens,
logprobs=all_logprobs,
hidden_states=all_hidden_states,
prompt_logprobs=prompt_logprobs)
non_spec_outputs = SpeculativeScores(
probs=non_spec_target_probs,
token_ids=non_spec_target_token_ids,
logprobs=non_spec_target_logprobs,
hidden_states=non_spec_target_hidden_states)
# Contract remaining nonspec entries based on non_spec_indices, if any.
return self._contract_non_speculative(
spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
non_spec_outputs, has_prompt_log)
def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape
# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
contracted_bs, k + 1)
target_probs = target_sampler_output.sampled_token_probs.reshape(
*target_token_ids.shape, self._vocab_size)
target_logprobs = target_sampler_output.logprobs.reshape(
target_probs.shape)
target_hidden_states = target_sampler_output.hidden_states
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])
return SpeculativeScores(probs=target_probs,
token_ids=target_token_ids,
logprobs=target_logprobs,
hidden_states=target_hidden_states,
prompt_logprobs=None)
def _create_scoring_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
"""Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring.
target_seq_ids_iter provides sequence ids for the expanded batch,
fulfilling the requirement that no seq id in the expanded batch is equal
to the seq id in the original batch.
"""
if not seq_group_metadata_list:
return []
target_seq_group_metadata = list(
chain.from_iterable(
self._create_target_seq_group_metadata(
seq_group_metadata,
proposal_token_ids,
i,
target_seq_ids_iter,
) for i, seq_group_metadata in enumerate(
seq_group_metadata_list)))
return target_seq_group_metadata
def _create_target_seq_group_metadata(
self,
input_seq_group_metadata: SequenceGroupMetadata,
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
batch_index: int,
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
"""Given an input sequence group metadata and a list of draft tokens,
create a list of target SequenceGroupMetadata, one for each
token id that needs to be scored.
Naive speculative decoding requires K target model scores, one for each
draft model token. However one can add a bonus token such that if each
token is accepted, then a final token may be sampled from the model.
This function creates K+1 target SequenceGroupMetadata to take
advantage of the bonus token.
"""
assert len(input_seq_group_metadata.seq_data) == 1, (
"Beam search "
"not supported in speculative decoding")
input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
token_ids_to_score = self._get_token_ids_to_score(
proposal_token_ids[batch_index])
sampling_params = input_seq_group_metadata.sampling_params
target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, token_ids in enumerate(token_ids_to_score):
target_seq_group_metadata_list.append(
self._create_single_target_seq_group_metadata(
input_seq_group_metadata,
input_seq_id,
next(target_seq_ids_iter),
token_ids,
sampling_params=sampling_params,
))
return target_seq_group_metadata_list
@staticmethod
def _create_single_target_seq_group_metadata(
seq_group_metadata: SequenceGroupMetadata,
seq_id: SeqId,
target_seq_id: TargetSeqId,
token_ids: List[TokenId],
sampling_params: SamplingParams,
) -> SequenceGroupMetadata:
"""Create a single target SequenceGroupMetadata.
Args:
seq_group_metadata: The metadata for the input sequence.
seq_id: The input sequence ID.
target_seq_id: The corresponding target sequence ID.
token_ids: The list of token ids that are to be appended to the
input sequence.
"""
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_token_ids = seq_data.prompt_token_ids_array
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
mrope_position_delta = seq_data.mrope_position_delta
new_seq_data_dict = {
target_seq_id:
SequenceData(
prompt_token_ids,
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids),
),
}
# This is a hack. Technically, spec decoding should compute
# num_lookahead slots at one shot, but instead, it expands the batch
# and evaluate one by one right now. context_len is seq_len - 1 because
# the kv cache is filled by a previous batch in the batch expansion.
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
data.mrope_position_delta = mrope_position_delta
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data=new_seq_data_dict,
sampling_params=sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
token_chunk_size=1,
)
@staticmethod
def _split_scoring_output(
sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
"""Split the target model output into speculative and non-speculative
output.
"""
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
#
# First samples are non-speculative, latter samples are from speculative
# scoring (prefill|decode order).
split_sizes = (sampler_output.sampled_token_ids.numel() -
num_scoring_tokens, num_scoring_tokens)
(non_spec_probs,
spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
(non_spec_sampled_tokens, spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
(non_spec_logprobs,
spec_logprobs) = sampler_output.logprobs.split(split_sizes)
if sampler_output.hidden_states is not None:
(non_spec_hidden_states, spec_hidden_states
) = sampler_output.hidden_states.split(split_sizes)
else:
non_spec_hidden_states, spec_hidden_states = None, None
return (spec_sampled_tokens, spec_probs, spec_logprobs,
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
non_spec_logprobs, non_spec_hidden_states)
@staticmethod
def _create_target_seq_id_iterator(
seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
"""Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored.
This implementation increments a counter starting at 1 + max of all
provided input sequence ids.
"""
return count(start=max(seq_ids) + 1)
@staticmethod
def _get_token_ids_to_score(
full_spec_token_ids: List[TokenId] # shape: [k]
) -> List[List[TokenId]]:
"""Given an int tensor of proposal token ids, return a list of
token ids that should be scored.
Returns k+1 output lists. The additional one is used for generating the
bonus token.
Example:
Input: [0, 1, 2, 3] (k=4)
Output: (k+1 lists)
[]
[0]
[0, 1]
[0, 1, 2]
[0, 1, 2, 3]
"""
empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend(full_spec_token_ids[:i + 1]
for i in range(len(full_spec_token_ids)))
return token_ids_to_score
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional
import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.sampler import SamplerOutput
try:
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except (ModuleNotFoundError, ImportError):
# vllm_flash_attn is not installed, try the ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except (ModuleNotFoundError, ImportError) as err:
raise RuntimeError(
"Draft model speculative decoding currently only supports "
"CUDA and ROCm flash attention backend.") from err
from vllm.logger import init_logger
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerWrapperBase)
logger = init_logger(__name__)
# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True
class TP1DraftModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always execute k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
TODOs:
1. Currently supports only flash-attn, add support for other attn_backends.
2. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""
def __init__(self, model_runner: ModelRunnerBase):
super().__init__(model_runner)
self.indices_of_seq_with_bonus_tokens = None
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
last_output: SamplerOutput) -> ModelRunnerInputBase:
# Currently, we expect "decode mode" only
assert not model_input.is_prompt
# Get num_seqs
num_seqs = len(model_input.seq_lens)
num_queries = len(model_input.query_lens)
# Get output tokens GPU tensor
sampled_token_ids = last_output.sampled_token_ids
assert sampled_token_ids is not None
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
attn_metadata.advance_step(model_input, sampled_token_ids,
self.block_size, num_seqs, num_queries)
# Update sampling_metadata
sampling_metadata = model_input.sampling_metadata
self._update_sampling_metadata(sampling_metadata, num_seqs,
num_queries)
# Create new input
new_model_input = self._model_input_cls(
input_tokens=model_input.input_tokens,
input_positions=model_input.input_positions,
attn_metadata=attn_metadata,
seq_lens=attn_metadata.seq_lens,
query_lens=model_input.query_lens,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
sampling_metadata=model_input.sampling_metadata,
is_prompt=False,
)
# Ensure we skip CPU samples
assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
# We can reuse sampling tensors since every decode iteration is the same
new_model_input.sampling_metadata.reuse_sampling_tensors = True
if debug_advance_input:
logger.debug("NEW INPUT: ")
logger.debug(" input_tokens = %s", new_model_input.input_tokens)
logger.debug(" input_positions = %s",
new_model_input.input_positions)
logger.debug(" seq_lens = %d", new_model_input.seq_lens)
logger.debug(" query_lens = %d", new_model_input.query_lens)
logger.debug(" attn_metadata:")
logger.debug(" seq_lens_tensor: %s",
attn_metadata.seq_lens_tensor)
logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
logger.debug(" block_tables: %s", attn_metadata.block_tables)
return new_model_input
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
"""Determines if draft_model_runner GPU multi-step can be used.
Currently required conditions are:
1. Only decodes
2. Only flash-attn
3. No LORA
4. No prompt_adapter_config
"""
if not allow_gpu_advance_step:
return False
# We allow multi-step GPU only in decode mode
for seq_group in execute_model_req.seq_group_metadata_list:
if seq_group.is_prompt:
return False
# TODO: Add support for other attn backends
if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
return False
# TODO: Add support for LORA
if self.lora_config:
return False
# TODO: Add soft-tuning prompt adapter support
return not self.prompt_adapter_config
def set_indices_of_seq_with_bonus_tokens(self,
indices_of_seq_with_bonus_tokens):
self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
@torch.inference_mode()
def execute_model(
self,
model_input: ModelRunnerInputBase,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
Optimizations used:
1. Input tensors are updated on the GPU directly
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
them since we do batch expansion later that uses GPU outputs)
3. Reuses sampling tensors (since we run only decodes and they have
a repeating sampling logic)
"""
# When num_steps == 1, we execute the fallback here for the GPU
# advance_step, which runs prepare_inputs on CPU and for each spec
# iteration invokes this function only once
# (Look at multi-step-worker code)
is_fallback = num_steps == 1
if not is_fallback:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
# Sanity
if self.lora_config is not None:
raise ValueError("TP1DraftModelRunner has no support for LORA")
if self.prompt_adapter_config is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"prompt_adapter_config")
if model_input.inputs_embeds is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"inputs_embeds")
if model_input.multi_modal_kwargs:
raise ValueError(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
)
else:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
self.attn_state.begin_forward(model_input)
# Detect exec mode
assert model_input.attn_metadata is not None
use_cuda_graph = False
if model_input.attn_metadata.num_prefills > 0:
# In this case, execute_model(..) was called directly
if num_steps > 1:
raise ValueError(
"execute_model(..) of draft_model_runner can be called "
"directly only with a single-step prefill")
else:
# We can skip CPU samples for spec token generation.
# (We do allow CPU samples for num_steps == 1 to support the
# fallback case, where supports_gpu_multi_step(..) does not pass)
model_input.sampling_metadata.skip_sampler_cpu_output = (
not is_fallback)
# Attn attr defines if we use cuda graphs
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
# Get model
if use_cuda_graph:
if model_input.inputs_embeds is None:
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
if previous_hidden_states is not None:
hidden_states = torch.cat([
previous_hidden_states,
torch.empty([
graph_batch_size - previous_hidden_states.shape[0],
*previous_hidden_states.shape[1:]
],
dtype=previous_hidden_states.dtype,
device=previous_hidden_states.device)
])
else:
hidden_states = None
else:
model_executable = self.model
hidden_states = previous_hidden_states
outputs: List[SamplerOutput] = []
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
model_execute_kwargs = {"previous_hidden_states": hidden_states} \
if previous_hidden_states is not None else {}
compute_logits_kwargs = {}
# Run model
if hasattr(self.model.config, "num_nextn_predict_layers"):
# for DeepSeek MTP only to use the corresponding layer for
# each step
spec_step_idx = kwargs.get("spec_step_idx", step)
model_execute_kwargs["spec_step_idx"] = spec_step_idx
compute_logits_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
inputs_embeds=None,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
device=self.device,
),
**model_execute_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata,
**compute_logits_kwargs)
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model_runner.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
outputs.append(output)
if self.return_hidden_states and is_fallback:
if use_cuda_graph:
indices = model_input.sampling_metadata\
.selected_token_indices
output.hidden_states = hidden_states[:len(indices)]
else:
output.hidden_states = hidden_states
if model_input.attn_metadata.num_prefills == 0 \
and self.indices_of_seq_with_bonus_tokens is not None:
assert output.sampled_token_ids is not None
# output.sampled_token_ids should be of shape (num_seqs, 1)
nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
assert num_tokens_per_seq == 1
count = 0
for i in range(nums_seqs):
bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
count]
if i != bonus_seq_idx:
# The following might cause a cpu->gpu sync
# However, the performance impact is negligible as we
# benchmarked on H100.
output.sampled_token_ids[
i, :] = model_input.input_tokens[bonus_seq_idx]
else:
count += 1
# Prepare inputs for the next step
if step != num_steps - 1:
model_input = self._gpu_advance_step(model_input, outputs[-1])
return outputs
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Set, Union
import torch
from vllm.sequence import ExecuteModelRequest, PromptLogprobs
from vllm.worker.worker_base import WorkerBase
@dataclass
class SpeculativeProposals:
"""Datastructure used to represent proposal tokens from some proposer. It
also tracks how many speculative tokens each sequence has.
"""
# Speculative proposal tokens.
proposal_token_ids: torch.Tensor
# Probabilities of the proposal tokens according to the proposer.
proposal_probs: torch.Tensor
# The valid length of each proposal; can be zero.
proposal_lens: torch.Tensor
# A flag to mark that there's no available proposals
no_proposals: bool = False
def __repr__(self):
return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids}, "
f"proposal_probs={self.proposal_probs.shape}, "
f"proposal_lens={self.proposal_lens})")
@dataclass
class SpeculativeScores:
"""Datastructure used to represent the scores of speculative tokens
according to the scoring model.
"""
# Probabilities of the speculative tokens according to the scoring model.
probs: torch.Tensor
# Log-probabilities of the speculative tokens according to the scoring
# model. These values can be used to generate Logprob objects that are
# returned to the user.
logprobs: torch.Tensor
# Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding.
token_ids: torch.Tensor
# Optional last hidden states from the scoring model.
hidden_states: Optional[torch.Tensor] = None
# Scoring model may also return logprobs for prompt tokens
# for each request, when chunked prefill is enabled.
prompt_logprobs: Optional[List[PromptLogprobs]] = None
def __repr__(self):
return (f"SpeculativeScores("
f"probs={self.probs.shape}, "
f"token_ids={self.token_ids.shape})")
class SpeculativeProposer(ABC):
@abstractmethod
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
# If set, this contains all sequence IDs that were assigned
# bonus tokens in their last forward pass.
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
raise NotImplementedError
class SpeculativeScorer(ABC):
def __init__(self, scorer_worker: WorkerBase,
device: Union[torch.device, str], vocab_size: int):
self._scorer_worker = scorer_worker
if isinstance(device, torch.device):
device = device.type
self._device = device
self._vocab_size = vocab_size
@abstractmethod
def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import DelegateWorkerBase
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
"""Worker for Medusa.
"""
def __init__(self, *args, **kwargs):
DelegateWorkerBase.__init__(self, *args, **kwargs)
# Lazy initialization list.
self._proposer: Top1Proposer
def init_device(self):
self.worker.init_device()
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
pass
def set_should_modify_greedy_probs_inplace(self):
pass
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
# Unused parameter.
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For medusa worker, this indicator shall be False.
"""
self._raise_if_unsupported(execute_model_req)
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states.
hidden_states,
sampling_metadata=sampling_metadata)
return model_outputs, False
def _prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[List[int], List[int]]:
if not seq_group_metadata_list:
return [], []
seq_lens: List[int] = []
query_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
is_prompt = seq_group_metadata.is_prompt
for seq_data in seq_group_metadata.seq_data.values():
seq_data_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(
seq_data_len,
context_len + seq_group_metadata.token_chunk_size)
seq_lens.append(seq_len)
query_lens.append(seq_len - context_len)
else:
seq_lens.append(seq_data_len)
query_lens.append(1)
return seq_lens, query_lens
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MedusaWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MedusaWorker does not support beam search.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment