# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from dataclasses import replace
from typing import Optional

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

logger = init_logger(__name__)

PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = 0
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 128


class OptRejectionSampler(nn.Module):
    """
    Rejection sampler for speculative decoding.

    It is derived from the algorithm described in
    https://arxiv.org/abs/2211.17192, with two deliberate relaxations:

    * Target-model tokens are chosen by invoking the wrapped sampler with
      ``all_greedy=True``. When a draft token is rejected, that target token
      is emitted instead of a token drawn from the adjusted distribution.
    * A draft token is accepted if it equals the target token, or if
      ``target_prob / draft_prob`` is at least a uniform threshold drawn
      from ``[0.1, 0.2)`` (rather than ``[0, 1)``).

    Terminology used in the implementation:

    accepted tokens: tokens that are accepted based on the relationship
        between the "raw" draft and target probabilities (see above).
    recovered tokens: the token emitted at the first rejected position. In
        this implementation it is the target model's token for that
        position.
    bonus tokens: if all proposed tokens are accepted, the bonus token is
        appended to the end of the sequence. It is taken only from the
        target model's logits. It is produced by the wrapped sampler rather
        than by the rejection kernel.
    output tokens: tokens finally generated by the rejection sampler.
        output tokens = accepted tokens + recovered tokens + bonus tokens
    """

    def __init__(self, sampler: Sampler):
        super().__init__()
        self.sampler = sampler
        logprobs_mode = self.sampler.logprobs_mode
        # "processed_*" modes ask for logits after the sampler's processing;
        # otherwise raw logits are requested from the sampler.
        self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
        # "*_logits" modes return logits as-is instead of log-softmaxing them.
        self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: torch.Tensor | None,
        # [num_tokens + batch_size, vocab_size]
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """
        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            logits (torch.Tensor):
                Target model's logits. Shape is
                [num_tokens + batch_size, vocab_size]. Rows from different
                requests are flattened into a single tensor because this is
                the shape of the model's output logits.
                NOTE: `logits` can be updated in place to save memory.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information. It is
                not modified; the greedy override is applied to a copy.
        Returns:
            SamplerOutput:
                Contains the final output token IDs with shape
                [batch_size, max_spec_len + 1] (rejected slots hold
                PLACEHOLDER_TOKEN_ID). If requested, it also contains their
                logprobs.
        """
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        assert logits is not None

        bonus_logits_indices = metadata.bonus_logits_indices
        target_logits_indices = metadata.target_logits_indices

        # Full-vocab logits are only needed to build logprobs, so only ask
        # the sampler for them when the caller requested logprobs.
        wants_logprobs = sampling_metadata.max_num_logprobs is not None

        # Pick a target token for every row of `logits` (draft positions and
        # bonus positions) in a single sampler call. The greedy override is
        # applied to a copy so the caller's metadata object is not mutated.
        sampler_output = self.sampler(
            logits=logits,
            sampling_metadata=replace(
                sampling_metadata,
                all_greedy=True,
                all_random=False,
                max_num_logprobs=-1 if wants_logprobs else None,
            ),
            predict_bonus_token=True,
            # Override the logprobs mode to return logits because they are
            # needed later to compute the accepted token logprobs.
            logprobs_mode_override=(
                "processed_logits"
                if self.is_processed_logprobs_mode
                else "raw_logits"
            ),
        )

        # Indexing with an index tensor creates new tensors with storage
        # separate from `logits` / `sampled_token_ids`.
        sampled_token_ids = sampler_output.sampled_token_ids
        target_logits = logits[target_logits_indices]
        target_tokens = sampled_token_ids[target_logits_indices]
        bonus_token_ids = sampled_token_ids[bonus_logits_indices]

        # Compute probability distribution from target logits.
        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            target_tokens,
            bonus_token_ids,
            sampling_metadata,
        )

        logprobs_tensors = None
        if wants_logprobs:
            assert sampler_output.logprobs_tensors is not None
            logprobs_tensors = self._get_logprobs_tensors(
                sampling_metadata.max_num_logprobs,
                metadata,
                sampler_output.logprobs_tensors.logprobs,
                output_token_ids,
            )

        return SamplerOutput(
            sampled_token_ids=output_token_ids,
            logprobs_tensors=logprobs_tensors,
        )

    def _get_logprobs_tensors(
        self,
        max_num_logprobs: int,
        metadata: SpecDecodeMetadata,
        logits: torch.Tensor,
        sampled_token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """Gather logprobs for every output slot of every request.

        Args:
            max_num_logprobs: Number of top logprobs to gather per token.
            metadata: Spec-decode metadata. Only `cu_num_sampled_tokens` is
                read here.
            logits: Per-row (raw or processed) logits returned by the
                sampler, one row per entry of the flattened model output.
            sampled_token_ids: Output of `rejection_sample`, shape
                [batch_size, max_spec_len + 1]. Rejected slots hold
                PLACEHOLDER_TOKEN_ID.

        Returns:
            The `LogprobsTensors` produced by `Sampler.gather_logprobs` for
            all `batch_size * (max_spec_len + 1)` slots.
        """
        # Shift right by one so entry i is the first logits row of request i
        # (assumes `cu_num_sampled_tokens` holds inclusive cumulative counts).
        cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
        cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]

        final_logits = logits.to(torch.float32)

        # NOTE: To avoid cpu-gpu synchronization, we now simply compute indices for
        # all draft tokens, including the rejected ones. The rejected tokens will
        # be filtered out in the `parse_output`.
        logit_start_indices = cu_num_sampled_tokens
        offsets = torch.arange(
            sampled_token_ids.shape[-1],
            device=logit_start_indices.device,
            dtype=logit_start_indices.dtype,
        )
        accepted_logit_indices = (
            logit_start_indices.unsqueeze(1) + offsets.unsqueeze(0)
        ).flatten()
        # Slots past the last row are clamped; they are filtered out later.
        accepted_logit_indices.clamp_(max=final_logits.shape[0] - 1)
        accepted_tokens = sampled_token_ids.clone().flatten()
        # we replace rejected token ids with 0 to avoid gather_logprobs error
        accepted_tokens[accepted_tokens == PLACEHOLDER_TOKEN_ID] = 0

        # Compute logprobs for accepted tokens.
        accepted_logits = final_logits[accepted_logit_indices]
        accepted_logprobs = (
            accepted_logits
            if self.is_logits_logprobs_mode
            else self.sampler.compute_logprobs(accepted_logits)
        )
        return self.sampler.gather_logprobs(
            accepted_logprobs,
            max_num_logprobs,
            accepted_tokens.to(torch.int64),
        )

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
        discard_req_indices: Sequence[int] = (),
        logprobs_tensors: LogprobsTensors | None = None,
    ) -> tuple[list[list[int]], LogprobsLists | None]:
        """Parse the output of the rejection sampler.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. The rejected tokens are
                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                and will be filtered out in this function.
            vocab_size: The size of the vocabulary. IDs >= vocab_size are
                also treated as invalid.
            discard_req_indices: Optional row indices to discard tokens in.
            logprobs_tensors: Optional logprobs tensors to filter.

        Returns:
            A list of lists of token IDs (one list per request), and the
            filtered logprobs lists, or None if `logprobs_tensors` is None.
            NOTE: logprobs are filtered before `discard_req_indices` is
            applied, so discarded rows still contribute logprobs entries.
        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens.
        valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
            output_token_ids_np < vocab_size
        )
        output_logprobs = None
        if logprobs_tensors is not None:
            cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
            filtered_tensors = logprobs_tensors.filter(valid_mask.flatten())
            output_logprobs = filtered_tensors.tolists(cu_num_tokens)

        if len(discard_req_indices) > 0:
            valid_mask[discard_req_indices] = False
        outputs = [
            row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
        ]
        return outputs, output_logprobs

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        metadata: SpecDecodeMetadata,
    ) -> torch.Tensor:
        """Apply penalties, allowed-token masks and bad-words to target logits.

        `logits` is expected to have one row per draft token. Per-request
        parameters are expanded to per-draft-token rows by repeating request
        i `metadata.num_draft_tokens[i]` times.

        Returns:
            The processed logits. The tensor may be modified in place.
        """
        has_penalties = not sampling_metadata.no_penalties
        any_penalties_or_bad_words = (
            sampling_metadata.bad_words_token_ids or has_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if any_penalties_or_bad_words:
            # One output history per draft position, including the earlier
            # draft tokens of the same request.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Calculate indices of target logits.
        if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
            num_requests = len(sampling_metadata.output_token_ids)
            num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
            original_indices = torch.arange(num_requests, device="cpu")
            repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
            repeat_indices = repeat_indices_cpu.to(
                device=logits.device, non_blocking=True
            )
            logits = self.apply_penalties(
                logits, sampling_metadata, metadata, repeat_indices, output_token_ids
            )

            # Apply allowed token ids.
            if sampling_metadata.allowed_token_ids_mask is not None:
                token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
                logits.masked_fill_(token_mask, float("-inf"))

        # Apply bad words exclusion.
        if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
            apply_bad_words_with_drafts(
                logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
            )

        return logits

    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        metadata: SpecDecodeMetadata,
        repeat_indices: torch.Tensor,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """Apply presence/frequency/repetition penalties per draft position.

        Per-request penalty parameters and prompt token ids are gathered with
        `repeat_indices`, so row r of `logits` uses the parameters of request
        `repeat_indices[r]`. Returns `logits` unchanged if no penalties are
        configured.
        """
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None

        prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
        presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
        frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
        repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]

        logits = apply_all_penalties(
            logits,
            prompt_token_ids,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            output_token_ids,
        )
        return logits

    @staticmethod
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: list[list[int]] | None = None,
    ) -> list[list[int]]:
        """Expand per-request output histories to one history per draft token.

        For a request with draft tokens `spec = [s0, s1, ..., s{k-1}]`, this
        yields `k` histories:
        `out`, `out + [s0]`, ..., `out + [s0, ..., s{k-2}]`. Requests with no
        draft tokens contribute nothing. If `spec_token_ids` is None, the
        input is returned unchanged.
        """
        if spec_token_ids is None:
            return output_token_ids

        result = []
        for out, spec in zip(output_token_ids, spec_token_ids):
            if len(spec) == 0:
                continue
            result.append(out)
            for i in range(len(spec) - 1):
                result.append([*result[-1], spec[i]])
        return result


def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # One target token id per draft position; presumably [num_tokens, 1]
    # as produced by the sampler -- only the element count matters here.
    target_tokens: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run the rejection kernel over all requests in the batch.

    Args:
        draft_token_ids: Flattened draft tokens of all requests.
        num_draft_tokens: Number of draft tokens per request.
        max_spec_len: Maximum number of draft tokens of any request.
        cu_num_draft_tokens: Inclusive cumulative sum of `num_draft_tokens`.
        draft_probs: Draft probabilities, or None (e.g. ngram drafting), in
            which case every draft probability is treated as 1.
        target_probs: Target-model probabilities per draft position.
        target_tokens: Target-model token per draft position. It is emitted
            when the draft token at that position is rejected.
        bonus_token_ids: Token appended when every draft token is accepted.
        sampling_metadata: Currently unused; kept for interface stability.

    Returns:
        An int32 tensor of shape [batch_size, max_spec_len + 1]. Slots after
        the last emitted token of each request hold PLACEHOLDER_TOKEN_ID.
    """
    assert draft_token_ids.ndim == 1
    # The kernel indexes `draft_probs` exactly like `target_probs`, as a
    # row-major [num_tokens, vocab_size] matrix.
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    # The kernel reads `target_tokens` as a flat array of num_tokens ids.
    assert target_tokens.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)
    assert draft_probs is None or draft_probs.shape == (num_tokens, vocab_size)
    assert target_tokens.numel() == num_tokens

    # Create output buffer.
    output_token_ids = torch.full(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        fill_value=PLACEHOLDER_TOKEN_ID,
        device=device,
    )

    # Relaxed acceptance threshold: torch.rand draws from [0, 1), so this is
    # a uniform threshold in [0.1, 0.2) per draft position.
    uniform_probs = torch.rand(
        (num_tokens,),
        dtype=torch.float32,
        device=device,
    )
    uniform_probs = uniform_probs * 0.1 + 0.1

    # One kernel program per request; every request goes through this path.
    rejection_random_sample_kernel[(batch_size,)](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        target_tokens,
        bonus_token_ids,
        uniform_probs,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids


# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    target_token_ids_ptr,  # num_tokens ids, one per draft position
    bonus_token_ids_ptr,  # [batch_size]
    uniform_probs_ptr,  # [num_tokens]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    # One program per request. Walk the request's draft tokens in order:
    # each one is accepted if it matches the target token, or if
    # target_prob / draft_prob >= uniform_prob. At the first rejection the
    # target token is written and the remaining positions are left as
    # placeholders. If nothing is rejected, the bonus token is written
    # right after the last draft token.
    req_idx = tl.program_id(0)
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # Start of this request's row in the flattened output buffer.
    out_row_ptr = output_token_ids_ptr + req_idx * (max_spec_len + 1)

    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            token_idx = start_idx + pos
            # Widen to int64 before multiplying by vocab_size: the flat
            # offset into a [num_tokens, vocab_size] matrix can exceed the
            # int32 range for large batches / vocabularies.
            row_offset = token_idx.to(tl.int64) * vocab_size

            # Negative draft ids are clamped to 0 before indexing.
            draft_token_id = tl.maximum(
                tl.load(draft_token_ids_ptr + token_idx), 0
            )
            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr + row_offset + draft_token_id)
            target_prob = tl.load(target_probs_ptr + row_offset + draft_token_id)
            draft_token_id = draft_token_id.to(tl.int64)
            target_token_id = tl.load(target_token_ids_ptr + token_idx).to(tl.int64)
            uniform_prob = tl.load(uniform_probs_ptr + token_idx)

            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if (draft_token_id == target_token_id) or (
                target_prob / draft_prob >= uniform_prob and draft_prob > 0
            ):
                token_id = draft_token_id
            else:
                rejected = True
                token_id = target_token_id
            tl.store(out_row_ptr + pos, token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(out_row_ptr + num_draft_tokens, bonus_token_id)