interfaces.py 3.49 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import ABC, abstractmethod
4
from dataclasses import dataclass
5
from typing import List, Optional, Set, Union
6
7
8

import torch

9
from vllm.sequence import ExecuteModelRequest, PromptLogprobs
10
from vllm.worker.worker_base import WorkerBase
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27


@dataclass
class SpeculativeProposals:
    """Datastructure used to represent proposal tokens from some proposer. It
    also tracks how many speculative tokens each sequence has.

    Only ``proposal_token_ids``, ``proposal_probs`` and ``proposal_lens`` are
    required. ``no_proposals`` defaults to ``False``, and the tree-style
    generation fields (``cart_candidates``, ``retrieve_indices``,
    ``tree_attn_masks``, ``tree_position_ids``) default to ``None``.
    """

    # Speculative proposal tokens.
    proposal_token_ids: torch.Tensor

    # Probabilities of the proposal tokens according to the proposer.
    proposal_probs: torch.Tensor

    # The valid length of each proposal; can be zero.
    proposal_lens: torch.Tensor

    # A flag to mark that there are no available proposals.
    no_proposals: bool = False

    # Cartesian candidates used in tree-style generation.
    cart_candidates: Optional[torch.Tensor] = None

    # Retrieve indices used in tree-style generation. NOTE(review): presumably
    # these select candidate paths out of the tree; confirm with the producer.
    retrieve_indices: Optional[torch.Tensor] = None

    # Attention masks used in tree-style generation.
    tree_attn_masks: Optional[torch.Tensor] = None

    # Position ids used in tree-style generation.
    tree_position_ids: Optional[torch.Tensor] = None

    def __repr__(self):
        # Token ids and lengths are printed in full. Only the shape of
        # proposal_probs is printed. no_proposals and the tree-style fields
        # are not included.
        return (f"SpeculativeProposals("
                f"proposal_token_ids={self.proposal_token_ids}, "
                f"proposal_probs={self.proposal_probs.shape}, "
                f"proposal_lens={self.proposal_lens})")


@dataclass
class SpeculativeScores:
    """Datastructure used to represent the scores of speculative tokens
    according to the scoring model.

    ``probs``, ``logprobs`` and ``token_ids`` are required. ``hidden_states``,
    ``logits`` and ``prompt_logprobs`` are optional and default to ``None``.
    """

    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Log-probabilities of the speculative tokens according to the scoring
    # model. These values can be used to generate Logprob objects that are
    # returned to the user.
    logprobs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

    # Optional last hidden states from the scoring model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional lm_head logits from the scoring model.
    logits: Optional[torch.Tensor] = None

    # Scoring model may also return logprobs for prompt tokens
    # for each request, when chunked prefill is enabled.
    prompt_logprobs: Optional[List[PromptLogprobs]] = None

    def __repr__(self):
        # Only the shapes of probs and token_ids are printed. logprobs and the
        # optional fields are not included.
        return (f"SpeculativeScores("
                f"probs={self.probs.shape}, "
                f"token_ids={self.token_ids.shape})")


class SpeculativeProposer(ABC):
    """Abstract interface for components that produce speculative proposals
    for the sequences in an ``ExecuteModelRequest``."""

    @abstractmethod
    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        # If set, this contains all sequence IDs that were assigned
        # bonus tokens in their last forward pass.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Return speculative proposals for ``execute_model_req``.

        Args:
            execute_model_req: The request to produce proposals for.
            seq_ids_with_bonus_token_in_last_step: Sequence IDs that were
                assigned bonus tokens in their last forward pass (may be
                empty).

        Returns:
            A ``SpeculativeProposals`` describing the proposed tokens.
        """
        raise NotImplementedError


class SpeculativeScorer(ABC):
    """Abstract interface for components that score ``SpeculativeProposals``
    and return ``SpeculativeScores``."""

    def __init__(self, scorer_worker: WorkerBase,
                 device: Union[torch.device, str], vocab_size: int):
        """Store the scorer worker, target device and vocabulary size.

        Args:
            scorer_worker: Worker stored as ``self._scorer_worker``
                (presumably the one that runs the scoring model).
            device: Either a ``torch.device`` or a device string. A
                ``torch.device`` is reduced to its ``type`` string.
            vocab_size: Vocabulary size, stored as ``self._vocab_size``.
        """
        self._scorer_worker = scorer_worker
        # Normalize to a device-type string. Note that torch.device.type
        # drops the device index (e.g. "cuda:1" becomes "cuda").
        if isinstance(device, torch.device):
            device = device.type
        self._device = device
        self._vocab_size = vocab_size

    @abstractmethod
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score ``proposals`` for the sequences in ``execute_model_req``.

        Args:
            execute_model_req: The request the proposals were made for.
            proposals: The speculative proposals to score.

        Returns:
            A ``SpeculativeScores`` produced by the scoring model.
        """
        raise NotImplementedError