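# interfaces.py: data structures and abstract interfaces for speculative
# decoding (token proposals, scoring-model scores, proposer and scorer APIs).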
from abc import ABC, abstractmethod
2
from dataclasses import dataclass
3
4
5

import torch

6
from vllm.sequence import ExecuteModelRequest
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25


@dataclass
class SpeculativeProposals:
    """Proposal tokens produced by some speculative proposer.

    Besides the proposed token ids and their probabilities under the
    proposer, it also tracks how many speculative tokens each sequence has
    (``proposal_lens``).
    """

    # Speculative proposal tokens.
    proposal_token_ids: torch.Tensor

    # Probabilities of the proposal tokens according to the proposer.
    proposal_probs: torch.Tensor

    # The valid length of each proposal; can be zero.
    proposal_lens: torch.Tensor

    def __repr__(self):
        # Token ids and lengths are printed in full, while only the shape of
        # proposal_probs is shown -- presumably because the probability
        # tensor is too large to print usefully (NOTE(review): confirm).
        return (f"SpeculativeProposals("
                f"proposal_token_ids={self.proposal_token_ids}, "
                f"proposal_probs={self.proposal_probs.shape}, "
                f"proposal_lens={self.proposal_lens})")


@dataclass
class SpeculativeScores:
    """Scores assigned to speculative tokens by the scoring model."""

    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Log-probabilities of the speculative tokens according to the scoring
    # model. These values can be used to generate Logprob objects that are
    # returned to the user.
    logprobs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

    def __repr__(self):
        # Only tensor shapes are printed to keep the representation compact.
        # Every field is listed so logprobs is visible when debugging.
        return (f"SpeculativeScores("
                f"probs={self.probs.shape}, "
                f"logprobs={self.logprobs.shape}, "
                f"token_ids={self.token_ids.shape})")


class SpeculativeProposer(ABC):
    """Abstract interface for components that propose speculative tokens."""

    @abstractmethod
    def get_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Produce speculative token proposals for a request.

        Args:
            execute_model_req: The model-execution request to propose
                speculative tokens for.

        Returns:
            The proposed tokens, their proposer probabilities and the valid
            proposal length of each sequence.

        Raises:
            NotImplementedError: Always; concrete proposers must override.
        """
        raise NotImplementedError


class SpeculativeScorer(ABC):
    """Abstract interface for components that score speculative proposals."""

    @abstractmethod
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score previously generated speculative proposals.

        Args:
            execute_model_req: The model-execution request the proposals
                belong to.
            proposals: The speculative proposals to score.

        Returns:
            The scoring model's probabilities, log-probabilities and sampled
            token ids for the proposals.

        Raises:
            NotImplementedError: Always; concrete scorers must override.
        """
        raise NotImplementedError