# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from functools import cached_property

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32

# Token id written into output positions that were rejected / not produced.
PLACEHOLDER_TOKEN_ID = -1


class MtpRejectionSampler(RejectionSampler):
    """
    The implementation strictly follows the algorithm described in
    https://arxiv.org/abs/2211.17192.
    However, we want to clarify the terminology used in the implementation:
    accepted tokens: tokens that are accepted based on the relationship
            between the "raw" draft and target probabilities.
    recovered tokens: tokens that are sampled based on the adjusted probability
        distribution, which is derived from both the draft and target
        probabilities.
    bonus tokens:
        If all proposed tokens are accepted, the bonus token is added to the
        end of the sequence. The bonus token is only sampled from the target
        probabilities. We pass in the bonus tokens instead of sampling them
        in the rejection sampler to allow for more flexibility in the
        sampling process. For example, we can use top_p, top_k sampling for
        bonus tokens, while spec decode does not support these sampling
        strategies.
    output tokens:
        Tokens are finally generated with the rejection sampler.
        output tokens = accepted tokens + recovered tokens + bonus tokens

    This sampler processes the batch as a dense [batch_size, k, vocab_size]
    tensor, so every request must carry the same number of draft tokens k.
    """

    def __init__(self):
        super().__init__()
        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store
        # this value in a variable for readability.
        self._num_bonus_tokens = 1

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''Run batched rejection sampling over the draft tokens.

        Args:
            metadata:
                Metadata for spec decoding. Every request must carry the same
                number of draft tokens (``metadata.num_draft_tokens[0]``).
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. The type allows None, but this
                sampler asserts it is provided (probability-free proposers
                such as ngram are not supported here).
            target_logits (torch.Tensor):
                Target model's logits. Shape is [num_tokens, vocab_size].
                Logits from different requests are flattened into a single
                tensor because this is the shape of the output logits.
                NOTE: `target_logits` can be updated in place to save memory.
            bonus_token_ids (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
        Returns:
            output_token_ids (torch.Tensor):
                A tensor of shape [batch_size, k + 1] containing the final
                output token IDs. Positions past the first rejection (and the
                bonus slot when a draft was rejected) hold
                `PLACEHOLDER_TOKEN_ID`.
        '''
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        assert draft_probs is not None
        # The dense [batch_size, k, vocab_size] reshapes below are only valid
        # when every request proposes the same number of draft tokens.
        num_draft_tokens = metadata.num_draft_tokens[0]
        assert all(n == num_draft_tokens
                   for n in metadata.num_draft_tokens), (
                       "MtpRejectionSampler requires the same number of "
                       "draft tokens for every request.")
        # [num_tokens, vocab_size]
        # NOTE(woosuk): `target_logits` can be updated in place inside the
        # `compute_probs` function.
        target_probs = compute_probs(
            target_logits,
            metadata.cu_num_draft_tokens,
            sampling_metadata,
            num_draft_tokens,
        )
        target_probs = target_probs.view(-1, num_draft_tokens,
                                         target_probs.shape[-1])
        # Greedy requests must accept only the target argmax; see
        # `_apply_greedy_targets` for why this is not handled by
        # `compute_probs` alone.
        target_probs = self._apply_greedy_targets(target_probs,
                                                  sampling_metadata)
        draft_probs = draft_probs.view(-1, num_draft_tokens,
                                       draft_probs.shape[-1])
        draft_token_ids = metadata.draft_token_ids.view(-1, num_draft_tokens)
        accepted, recovered_token_ids = (
            self._batch_modified_rejection_sampling(
                target_probs,
                draft_probs,
                draft_token_ids,
                None,
            ))
        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )
        return output_token_ids

    def _apply_greedy_targets(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Replace greedy requests' target distribution with a one-hot on the
        target argmax.

        `compute_probs` returns the raw logits when every request is greedy,
        and in a mixed batch it gives requests with temperature <= 0 a
        temperature of 1. Feeding either straight into rejection sampling
        would treat greedy requests as random ones. With a one-hot target a
        draft token is accepted iff it equals the target argmax, and a
        rejected position recovers the argmax (almost surely, given the draft
        token had non-zero draft probability; see the tiny floor used in
        `_get_recovered_probs`).

        Returns:
            `target_probs` unchanged when `sampling_metadata.all_random` is
            set; otherwise a [batch_size, k, vocab_size] tensor where greedy
            rows are one-hot in `probs_dtype`.
        """
        if sampling_metadata.all_random:
            return target_probs
        # argmax is unaffected by temperature / top-k / top-p / softmax, so it
        # equals the argmax of the original logits.
        argmax_ids = target_probs.argmax(dim=-1, keepdim=True)
        greedy_probs = torch.zeros(target_probs.shape,
                                   dtype=self.probs_dtype,
                                   device=target_probs.device)
        greedy_probs.scatter_(-1, argmax_ids, 1.0)
        if sampling_metadata.all_greedy:
            return greedy_probs
        # Mirror `compute_probs`, which only keeps temperatures > 0.
        # sampling_metadata.temperature is expected to be [batch_size].
        is_greedy = (sampling_metadata.temperature <= 0).view(-1, 1, 1)
        return torch.where(is_greedy, greedy_probs, target_probs)

    def _create_uniform_samples(self,
                                seeded_seqs: Optional[dict[int,
                                                           torch.Generator]],
                                batch_size: int, k: int,
                                device: torch.device) -> torch.Tensor:
        """
        Generates a batch of uniform random samples, with optional seeding
        for specific sequences.

        This method creates a tensor of shape `(batch_size, k + 1)` filled
        with uniform random values in the range [0, 1). If `seeded_seqs`
        is provided, the sequences corresponding to specific indices
        will be generated using the provided `torch.Generator` for
        reproducibility. The other sequences will be generated without
        a seed.

        Args:
            seeded_seqs : Optional[dict[int, torch.Generator]]
                A dictionary mapping indices in the batch to
                `torch.Generator` objects. If `None`, all samples are
                generated without a seed.
            batch_size : int
                The number of sequences to generate.
            k : int
                One less than the number of random samples per sequence
                (each row has k + 1 samples).
            device : torch.device
                The device on which to allocate the tensor.

        Returns:
            uniform_rand : torch.Tensor
                A tensor of shape `(batch_size, k + 1)` containing uniform
                random values in the range [0, 1).
        """
        if not seeded_seqs:
            return torch.rand(batch_size, k + 1, device=device)

        uniform_rand = torch.empty(batch_size, k + 1, device=device)

        non_seeded_indices = []
        for idx in range(batch_size):
            generator = seeded_seqs.get(idx)
            if generator is None:
                non_seeded_indices.append(idx)
            else:
                uniform_rand[idx, :] = torch.rand(1,
                                                  k + 1,
                                                  dtype=self.probs_dtype,
                                                  device=device,
                                                  generator=generator)
        if non_seeded_indices:
            # One batched draw for all unseeded rows.
            uniform_rand[non_seeded_indices, :] = torch.rand(
                len(non_seeded_indices),
                k + 1,
                dtype=self.probs_dtype,
                device=device)
        return uniform_rand

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[dict[int, torch.Generator]],
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
        $\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
        to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        $$
        \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                        {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
        $$

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]

        # k - 1 because the helper returns k - 1 + 1 = k samples per row.
        uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
                                                    k - 1,
                                                    target_probs.device)

        capped_ratio = (selected_target_probs /
                        selected_draft_probs).clamp(max=1.0)
        accepted = uniform_rand < capped_ratio

        return accepted

    def _get_recovered_probs(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given $q(x|x_1, \dots, x_n)$, the probability of
        $x$ given context $x_1, \dots, x_n$ according to the target
        model and $p(x|x_1, \dots, x_n)$, the same conditional probability
        according to the draft model:

        $$
        x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
        $$

        where $(f(x))_+$ is defined as:

        $$
        (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
        $$

        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
        of the draft, target, and recovered probability distributions.

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note:
            This batches operations on GPU and thus constructs the recovered
            distribution for all tokens, even if they are accepted. This causes
            division-by-zero errors, so we use self._smallest_positive_value to
            avoid that. This introduces some drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO(cade): Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[dict[int, torch.Generator]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor of which tokens in each sequence is accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """

        batch_size, k, vocab_size = target_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids, seeded_seqs)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        # NOTE: the recovered_probs are overwritten by this method.
        recovered_token_ids = _multinomial(
            recovered_probs,
            num_samples=1,
            k=k,
            seeded_seqs=seeded_seqs or {},
        ).reshape(batch_size, k)

        return accepted, recovered_token_ids

    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            substitute_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size, 1] or [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids.

        Draft tokens before the first rejection are kept, the first rejected
        position receives the substitute (recovered) token, and every later
        position is set to `PLACEHOLDER_TOKEN_ID` for that sequence. The bonus
        column is filled only when all k draft tokens were accepted.

        Args:
            accepted: A boolean tensor indicating if the corresponding
            draft token in draft_token_ids should be accepted or not.
            substitute_token_ids: A tensor of token_ids that can be used
            as substitutes for the draft token ids if the proposed token
            is rejected.
            draft_token_ids: A tensor of token ids speculated by the
            draft model.
            bonus_token_ids: Token ids to use as the bonus token if
            all the draft tokens are accepted.
        Returns:
            A tensor containing the accepted token ids. The shape of the
            tensor is [batch_size, k + num_bonus_tokens]
        """
        batch_size, k = substitute_token_ids.shape
        bonus_token_ids = bonus_token_ids.squeeze(-1)
        # Determine the index of the first False value for each row.
        limits = (accepted == 0).max(1).indices
        # Rows with no rejection keep all k draft tokens.
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor
        output_with_bonus_tokens = torch.full(
            (batch_size, k + self._num_bonus_tokens),
            PLACEHOLDER_TOKEN_ID,
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(
            accepted_mask, draft_token_ids,
            torch.full_like(draft_token_ids, PLACEHOLDER_TOKEN_ID))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(
            output[:, -1] != PLACEHOLDER_TOKEN_ID, bonus_token_ids,
            PLACEHOLDER_TOKEN_ID)

        # Fill the recovered token ids (the first rejected position of each
        # row); `output` is a view, so this also updates the returned tensor.
        output.mul_(~after_false_mask).add_(
            substitute_token_ids.mul(after_false_mask))

        return output_with_bonus_tokens

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
    ) -> list[list[int]]:
        """Parse the output of the rejection sampler.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. The rejected tokens are
                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                and will be filtered out in this function.
            vocab_size: The size of the vocabulary.

        Returns:
            A list of lists of token IDs.
        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens.
        valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
                      (output_token_ids_np < vocab_size))
        outputs = [
            row[valid_mask[i]].tolist()
            for i, row in enumerate(output_token_ids_np)
        ]
        return outputs

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to sample
        recovered tokens in the first rejection case.

        See _get_recovered_probs for more details

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny

    @property
    def probs_dtype(self):
        return torch.float32

    @property
    def token_id_dtype(self):
        return torch.int64


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    k: int,
    seeded_seqs: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample `num_samples` indices per row of `probs` via the
    exponential-race trick (argmax of probs / Exp(1) noise).

    Rows are grouped in chunks of `k`; chunk `idx` draws its noise from
    `seeded_seqs.get(idx)` when seeded generators are supplied.
    Returns a tensor of shape [num_rows, num_samples].
    """
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if not seeded_seqs:
        q.exponential_(1.0)
    else:
        start = 0
        for idx in range(len(q) // k):
            end = start + k
            generator = seeded_seqs.get(idx)
            # Note: generator might be None for non seeded
            q[start:end].exponential_(1.0, generator=generator)
            start = end

    return probs.div_(q).argmax(dim=1).view(-1, num_samples)


def compute_probs(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    sampling_metadata: SamplingMetadata,
    spec_len: int,
) -> torch.Tensor:
    """Compute probability distribution from logits based on sampling metadata.

    This function applies temperature scaling to the logits and converts
    them to probabilities using softmax. For greedy decoding, it returns
    the original logits.

    Args:
        logits: Input logits tensor to be converted to probabilities.
        cu_num_draft_tokens: Cumulative number of draft tokens (only its rank
            is checked here).
        sampling_metadata: Metadata containing sampling parameters such as
            temperature and whether greedy sampling is used.
        spec_len: Number of draft tokens per request; per-request
            temperature / top-k / top-p values are repeated this many times
            so they line up with the flattened token rows.

    Returns:
        torch.Tensor: Probability distribution (softmax of scaled logits)
            if non-greedy sampling is used, otherwise returns the
            original logits.
    """
    assert logits.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    if sampling_metadata.all_greedy:
        return logits

    temperature = sampling_metadata.temperature.view(-1, 1).repeat(
        1, spec_len).view(-1)
    # Non-positive temperatures (greedy rows in a mixed batch) are replaced
    # with 1 to avoid dividing by zero / flipping signs.
    temperature = torch.where(temperature > 0, temperature, 1)
    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
    logits.div_(temperature.unsqueeze(-1))

    # Get expanded top_k and top_p tensors.
    top_k = None
    if sampling_metadata.top_k is not None:
        top_k = sampling_metadata.top_k.view(-1, 1).repeat(1,
                                                           spec_len).view(-1)
    top_p = None
    if sampling_metadata.top_p is not None:
        top_p = sampling_metadata.top_p.view(-1, 1).repeat(1,
                                                           spec_len).view(-1)

    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
    # which is slow for large vocab sizes. This may cause performance issues.
    logits = apply_top_k_top_p(logits, top_k, top_p)
    output_prob = logits.softmax(dim=-1, dtype=torch.float32)
    return output_prob