# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
from abc import ABC

import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.utils import async_tensor_h2d

_SAMPLING_EPS = 1e-5


def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
    """True if request is incompatible with speculative decoding"""
    # Any non-neutral penalty, a min_p above the epsilon, or a logprobs
    # request makes the request unsupported.
    return (sampling_params.frequency_penalty != 0.0
            or sampling_params.presence_penalty != 0.0
            or sampling_params.repetition_penalty != 1.0
            or sampling_params.min_p > _SAMPLING_EPS
            or sampling_params.logprobs is not None)


# For program `pid`, writes the consecutive values
#   index_start, index_start + 1, ..., index_start + num_tokens - 1
# into out[start_pos:end_pos], where
#   start_pos = cu_num_tokens[pid], end_pos = cu_num_tokens[pid + 1],
#   index_start = cu_query_lens[pid].
# The range is processed in BLOCK_SIZE-wide chunks with a tail mask.
# NOTE(review): presumably launched with one program per request and with
# cumulative (prefix-sum) arrays of length num_requests + 1 -- confirm
# against the launch site.
@triton.jit
def prepare_eagle_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    # [start_pos, end_pos)
    start_pos = tl.load(cu_num_tokens_ptr + pid)
    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
    num_tokens = end_pos - start_pos

    index_start = tl.load(cu_query_lens_ptr + pid)

    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
    for i in tl.range(num_blocks):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # The mask drops lanes past num_tokens in the final chunk.
        tl.store(
            out_ptr + start_pos + offset,
            index_start + offset,
            mask=offset < num_tokens,
        )


class DraftProbs(ABC):  # type: ignore[call-arg]
    """Draft probs corresponding to in-progress sequences.

    Row ``i`` of ``draft_probs`` belongs to request ``_req_ids[i]``; every
    method keeps the tensor rows and the id list in lockstep.
    """

    # spec tokens probs; row i corresponds to _req_ids[i].
    draft_probs: torch.Tensor
    # The request id list, aligned with the rows of draft_probs.
    _req_ids: list[str]

    def __init__(self, draft_probs, req_ids):
        """Track ``draft_probs`` whose rows are aligned with ``req_ids``."""
        assert len(req_ids) == len(draft_probs)
        self.draft_probs = draft_probs
        # Copy so that later mutation of the caller's list cannot
        # desynchronize the ids from the tensor rows.
        self._req_ids = list(req_ids)

    def update(self, draft_probs: torch.Tensor, tmp_req_ids: list[str]):
        """Replace or append the draft probs of ``tmp_req_ids``.

        Tracked rows whose id appears in ``tmp_req_ids`` are dropped. The
        remaining rows keep their relative order, and the rows of
        ``draft_probs`` (aligned with ``tmp_req_ids``) are appended after
        them.
        """
        # Check before mutating so a bad call leaves the state intact.
        # This is equivalent to the post-condition asserted below.
        assert len(tmp_req_ids) == len(draft_probs)
        incoming = set(tmp_req_ids)
        kept = [(row, req_id) for row, req_id in enumerate(self._req_ids)
                if req_id not in incoming]
        kept_rows = [row for row, _ in kept]
        self.draft_probs = torch.cat(
            [self.draft_probs[kept_rows], draft_probs])
        self._req_ids = [req_id for _, req_id in kept]
        self._req_ids.extend(tmp_req_ids)
        assert len(self._req_ids) == len(self.draft_probs)

    def prune(self, req_ids: list[str]):
        """Drop the rows of every tracked request listed in ``req_ids``."""
        removed = set(req_ids)
        kept = [(row, req_id) for row, req_id in enumerate(self._req_ids)
                if req_id not in removed]
        if len(kept) != len(self._req_ids):
            # Batch contents changed - prune removed sequences.
            self.draft_probs = self.draft_probs[[row for row, _ in kept]]
            self._req_ids = [req_id for _, req_id in kept]

    def get_probs(self, req_ids: list[str]):
        """Return the draft-prob rows of ``req_ids``, in that order.

        The row indices are copied to ``self.draft_probs.device`` as an
        int32 tensor via ``async_tensor_h2d`` (pinned host memory) before
        gathering.

        Raises:
            ValueError: if an id in ``req_ids`` is not tracked (same
                exception type as ``list.index``).
        """
        # setdefault keeps the first occurrence, mirroring list.index.
        position: dict[str, int] = {}
        for row, req_id in enumerate(self._req_ids):
            position.setdefault(req_id, row)
        try:
            index = [position[req_id] for req_id in req_ids]
        except KeyError as err:
            raise ValueError(f"{err.args[0]!r} is not in list") from None
        index_tensor = async_tensor_h2d(index,
                                        dtype=torch.int32,
                                        target_device=self.draft_probs.device,
                                        pin_memory=True)
        return self.draft_probs[index_tensor]