# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC

import msgspec
import torch

from vllm.sampling_params import SamplingParams

_SAMPLING_EPS = 1e-5


def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
    """True if request is incompatible with speculative decoding.

    Any non-neutral frequency/presence/repetition penalty, a ``min_p``
    above ``_SAMPLING_EPS``, or a logprobs request marks the request as
    unsupported.
    """
    return (sampling_params.frequency_penalty != 0.0
            or sampling_params.presence_penalty != 0.0
            or sampling_params.repetition_penalty != 1.0
            or sampling_params.min_p > _SAMPLING_EPS
            or sampling_params.logprobs is not None)


class DraftProbs(ABC):  # type: ignore[call-arg]
    """Draft probs corresponding to in-progress sequences.

    Row ``i`` of ``draft_probs`` (leading dimension) belongs to request
    ``_req_ids[i]``; every method keeps the two aligned.
    """
    # NOTE(review): derives from ABC but declares no abstract methods.

    # Spec-token probabilities, one leading-dimension row per request.
    draft_probs: torch.Tensor
    # Request ids, aligned with the leading dimension of ``draft_probs``.
    _req_ids: list[str]

    def __init__(self, draft_probs: torch.Tensor, req_ids: list[str]):
        """Track ``draft_probs`` whose rows belong to ``req_ids`` in order."""
        assert len(req_ids) == len(draft_probs)
        self.draft_probs = draft_probs
        self._req_ids = req_ids

    def update(self, draft_probs: torch.Tensor, tmp_req_ids: list[str]):
        """Store fresh rows for ``tmp_req_ids``.

        Any existing rows for ids in ``tmp_req_ids`` are dropped. The
        remaining rows keep their relative order, and the new
        ``draft_probs`` rows are appended after them in ``tmp_req_ids``
        order.

        Args:
            draft_probs: new rows, one per entry of ``tmp_req_ids``.
            tmp_req_ids: request ids for the new rows.

        Raises:
            AssertionError: if ``draft_probs`` and ``tmp_req_ids`` differ
                in length; checked before any state is modified.
        """
        # Validate up front so a bad call cannot leave ``draft_probs`` and
        # ``_req_ids`` out of sync (previously checked only after mutating).
        assert len(tmp_req_ids) == len(draft_probs)
        incoming = set(tmp_req_ids)  # O(1) membership instead of list scans
        kept = [
            i for i, req_id in enumerate(self._req_ids)
            if req_id not in incoming
        ]
        kept_ids = [self._req_ids[i] for i in kept]
        kept_ids.extend(tmp_req_ids)
        # Single assignment: the slice and the concat either both succeed
        # or the tensor is left untouched.
        self.draft_probs = torch.cat([self.draft_probs[kept], draft_probs])
        self._req_ids = kept_ids

    def prune(self, req_ids: list[str]):
        """Drop the rows of every tracked request whose id is in ``req_ids``.

        Ids in ``req_ids`` that are not tracked are ignored. If nothing is
        removed, the state is left untouched.
        """
        removed = set(req_ids)
        kept = [
            i for i, req_id in enumerate(self._req_ids)
            if req_id not in removed
        ]
        # ``kept`` is a filtered subsequence, so a length change is exactly
        # "batch contents changed".
        if len(kept) != len(self._req_ids):
            self.draft_probs = self.draft_probs[kept]
            self._req_ids = [self._req_ids[i] for i in kept]

    def get_probs(self, req_ids: list[str]) -> torch.Tensor:
        """Return the rows for ``req_ids``, in the order given.

        Raises:
            ValueError: if an id in ``req_ids`` is not tracked (matching the
                previous ``list.index``-based behaviour).
        """
        position: dict[str, int] = {}
        for i, req_id in enumerate(self._req_ids):
            # Keep the first occurrence, as ``list.index`` would.
            position.setdefault(req_id, i)
        try:
            index = [position[req_id] for req_id in req_ids]
        except KeyError as err:
            raise ValueError(f"{err.args[0]!r} is not in list") from None
        return self.draft_probs[index]