# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

logger = init_logger(__name__)

PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32


class OptRejectionSampler(nn.Module):
    """Rejection sampler for speculative decoding.

    The implementation is based on the algorithm described in
    https://arxiv.org/abs/2211.17192.
    However, we want to clarify the terminology used in the implementation:
    accepted tokens: tokens that are accepted based on the relationship
            between the "raw" draft and target probabilities.
    recovered tokens: tokens that are sampled based on the adjusted probability
        distribution, which is derived from both the draft and target
        probabilities.
    bonus tokens:
        If all proposed tokens are accepted, the bonus token is added to the
        end of the sequence. The bonus token is only sampled from the target
        probabilities. We pass in the bonus tokens instead of sampling them
        in the rejection sampler to allow for more flexibility in the
        sampling process. For example, we can use top_p, top_k sampling for
        bonus tokens, while spec decode does not support these sampling
        strategies.
    output tokens:
        Tokens are finally generated with the rejection sampler.
        output tokens = accepted tokens + recovered tokens + bonus tokens

    NOTE(review): this variant deviates from the paper in ways visible in
    this file: the uniform acceptance thresholds are squeezed into
    [0.1, 0.2), a draft token that equals the supplied target token is
    always accepted, and the token written at the first rejected position
    is taken from `target_tokens` instead of being sampled from an adjusted
    distribution here.
    """

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # One token ID per draft position (read by flat index in the kernel).
        target_tokens: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Run rejection sampling over the flattened draft tokens.

        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
                Target model's logits. Shape is [num_tokens, vocab_size].
                Here, logits from different requests are flattened into a
                single tensor because this is the shape of the output
                logits. They are converted to float32 probabilities with a
                softmax over the last dimension.
            target_tokens (torch.Tensor):
                Token IDs associated with each draft position. The kernel
                reads one element per draft position by flat index.
                NOTE(review): presumably the target model's own token
                choice at each position -- confirm against the caller.
            bonus_token_ids (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
                It is forwarded to `rejection_sample`.

        Returns:
            output_token_ids (torch.Tensor):
                int32 tensor of shape [batch_size, max_spec_len + 1] holding
                the final output token IDs; unused slots contain
                `PLACEHOLDER_TOKEN_ID`.
        """
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        # Compatibility with the first decode step: placeholder draft IDs
        # (-1) are mapped to token 0 so that they remain valid vocabulary
        # indices when the kernel gathers probabilities.
        draft_token_ids = metadata.draft_token_ids
        is_placeholder = draft_token_ids == PLACEHOLDER_TOKEN_ID
        draft_token_ids = torch.where(is_placeholder, 0,
                                      draft_token_ids).to(torch.long)

        output_token_ids = rejection_sample(
            draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            target_tokens,
            bonus_token_ids,
            sampling_metadata,
        )
        return output_token_ids

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
    ) -> list[list[int]]:
        """Parse the output of the rejection sampler.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. The rejected tokens are
                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                and will be filtered out in this function.
            vocab_size: The size of the vocabulary.

        Returns:
            A list of lists of token IDs.
        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens: drop placeholders and any ID that is
        # not smaller than the vocabulary size.
        valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
                      (output_token_ids_np < vocab_size))
        outputs = [
            row[valid_mask[i]].tolist()
            for i, row in enumerate(output_token_ids_np)
        ]
        return outputs


def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # One token ID per draft position (read by flat index in the kernel).
    target_tokens: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Allocate the output buffer and launch the rejection-sampling kernel.

    One kernel program is launched per request (grid size `batch_size`).

    Args:
        draft_token_ids: Flattened draft token IDs, shape [num_tokens].
        num_draft_tokens: Number of draft tokens per request; only its
            length (the batch size) is used here.
        max_spec_len: Maximum number of draft tokens per request; the output
            has `max_spec_len + 1` columns.
        cu_num_draft_tokens: 1-D tensor the kernel uses to find each
            request's start/end offset into the flattened token axis.
        draft_probs: Draft probabilities, or None when unavailable. The
            kernel reads it as a contiguous row-major
            [num_tokens, vocab_size] buffer.
        target_probs: Target probabilities, shape [num_tokens, vocab_size].
        target_tokens: Token IDs read by the kernel, one per draft position.
        bonus_token_ids: Bonus token per request, shape [batch_size, 1].
        sampling_metadata: Currently unused by this function.

    Returns:
        int32 tensor of shape [batch_size, max_spec_len + 1], pre-filled
        with `PLACEHOLDER_TOKEN_ID` and then written by the kernel.
    """
    assert draft_token_ids.ndim == 1
    # The documented layout is 2-D [num_tokens, vocab_size], which is also
    # how the kernel indexes the buffer. The previous check only admitted
    # 3-D tensors and therefore rejected the documented layout; 3-D inputs
    # are still accepted so existing callers keep working.
    assert draft_probs is None or draft_probs.ndim in (2, 3)
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    # The kernel does raw pointer arithmetic, so the buffers must be
    # contiguous.
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Create output buffer.
    output_token_ids = torch.full(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        fill_value=PLACEHOLDER_TOKEN_ID,
        device=device,
    )

    uniform_probs = torch.rand(
        (num_tokens, ),
        dtype=torch.float32,
        device=device,
    )
    # NOTE(review): torch.rand draws from [0, 1); the affine map below
    # squeezes the acceptance thresholds into [0.1, 0.2). That makes
    # acceptance more permissive than the exact rejection-sampling test --
    # presumably intentional for this "Opt" variant; confirm before changing.
    uniform_probs = uniform_probs * 0.1 + 0.1

    # Rejection sampling for random sampling requests.
    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        target_tokens,
        bonus_token_ids,
        uniform_probs,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids


# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    target_token_ids_ptr,  # one ID per draft position (flat index below)
    bonus_token_ids_ptr,  # [batch_size]
    uniform_probs_ptr,  # [num_tokens]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    # Each program instance handles one request: it walks that request's
    # draft tokens in order, writes accepted tokens to the request's output
    # row, stops accepting at the first rejection, and appends the bonus
    # token only if nothing was rejected.
    req_idx = tl.program_id(0)
    # The request's draft tokens live in [start_idx, end_idx) of the
    # flattened token axis; entry `req_idx` of cu_num_draft_tokens is the
    # end offset and the previous entry (or 0 for the first request) is the
    # start offset.
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        # Once a token has been rejected, the remaining positions are
        # skipped and their output slots are left untouched.
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                # No draft distribution supplied: treat the draft
                # probability as 1.
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            target_token_id = tl.load(target_token_ids_ptr + (start_idx + pos))
            target_token_id = target_token_id.to(tl.int64)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, the ratio
            # test below fails.
            # A draft token is accepted when it equals the supplied target
            # token, or when target_prob / draft_prob reaches the uniform
            # threshold (with a non-zero draft_prob).
            if (draft_token_id == target_token_id) or (
                    target_prob / draft_prob >= uniform_prob
                    and draft_prob > 0):
                token_id = draft_token_id
            else:
                # First rejection: emit the supplied target token instead of
                # the draft token and stop accepting further drafts.
                rejected = True
                token_id = target_token_id
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)