# ngram_worker.py
1
import weakref
2
from typing import List, Optional, Tuple
3
4
5

import torch

6
from vllm.sequence import ExecuteModelRequest, SamplerOutput
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase


class NGramWorker(LoraNotSupportedWorkerBase):
    """NGramWorker provides a lightweight drafter that needs no model.

    NGramWorker currently only implements prompt lookup decoding: it proposes
    speculative tokens by finding an earlier occurrence of a sequence's
    trailing n-gram and copying the tokens that followed it. In the future we
    may also add RAG-style drafters and other scenarios that don't rely on an
    LLM to produce proposals.
    """

    def __init__(self, *args, **kwargs):
        """Record the settings this worker needs from the worker kwargs.

        Only ``local_rank`` and ``model_config`` are read from ``kwargs`` (a
        missing key raises ``KeyError``); all other arguments are ignored.
        ``set_ngram_window_size`` must be called before ``sampler_output``,
        because the n-gram window is not initialized here.
        """
        self.local_rank = kwargs["local_rank"]
        self.vocab_size = kwargs["model_config"].get_vocab_size()

        # Declared here; the proposer is created in init_device().
        self._proposer: Top1Proposer

    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        """Set the inclusive range of n-gram sizes to search.

        ``sampler_output`` tries sizes from ``ngram_prompt_lookup_max``
        (capped at the sequence length minus one) down to
        ``ngram_prompt_lookup_min``, and stops at the first size that matches.
        """
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

    def init_device(self):
        """Select this rank's CUDA device and build the proposer."""
        self.device = torch.device(f"cuda:{self.local_rank}")
        # There are no weights to load, so load_model is a no-op.
        self.load_model = lambda *args, **kwargs: None

        # Only Top1Proposer is supported for now. It is given a weak proxy,
        # presumably to avoid a strong reference cycle between this worker and
        # the proposer it owns.
        self._proposer = Top1Proposer(
            weakref.proxy(self),
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def set_include_gpu_probs_tensor(self):
        # NGram doesn't need a GPU sampler.
        pass

    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
        """NGram doesn't depend on model execution, so this is a no-op."""
        pass

    def determine_num_available_blocks(self) -> None:
        """NGram doesn't run a model, so there are no blocks to check."""
        pass

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """NGram keeps no cache, so there is nothing to initialize."""
        pass

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes (always 0 for NGram)."""
        return 0

    def _find_ngram_proposal(
        self,
        input_ids: torch.Tensor,
        input_length: int,
        sample_len: int,
    ) -> Optional[torch.Tensor]:
        """Return ``sample_len`` proposed token ids for one sequence, or None.

        N-gram sizes are tried from largest to smallest. For the first size
        whose trailing n-gram also occurs earlier in ``input_ids``, the tokens
        that follow the first such occurrence are returned. Indices past the
        end of the sequence are clamped to its last position, so a proposal
        near the end repeats the final token. Returns None when no size in
        the window matches.

        Args:
            input_ids: 1-D tensor of the sequence's token ids.
            input_length: Sequence length, used to cap the n-gram size.
            sample_len: Number of tokens to propose.
        """
        for ngram_size in range(
                min(self.ngram_prompt_lookup_max, input_length - 1),
                self.ngram_prompt_lookup_min - 1,
                -1,
        ):
            ngram_tensor = input_ids[-ngram_size:]
            if ngram_size == 1:
                # Compare element-wise directly instead of building windows.
                # The final position is excluded so the n-gram does not
                # match itself.
                matches = (input_ids[:-1] == ngram_tensor)
            else:
                windows = input_ids.unfold(dimension=0,
                                           size=ngram_size,
                                           step=1)
                # The last window is the trailing n-gram itself; skip it.
                matches = (windows[:-1] == ngram_tensor).all(dim=-1)

            # first_match.values (bool) tells whether any match was found and
            # first_match.indices gives the index of the first match.
            # Note that ".item()" triggers a GPU-CPU sync, so it is a bit
            # inefficient, but we have not found a better way to do this.
            first_match = matches.max(dim=-1)
            if first_match.values.item():
                # The proposal starts right after the matched n-gram.
                proposal_start_idx = first_match.indices.add_(ngram_size)
                spec_indices = proposal_start_idx.repeat(
                    sample_len) + torch.arange(sample_len, device=self.device)
                spec_indices.clamp_(max=input_ids.shape[-1] - 1)
                return input_ids.gather(dim=-1, index=spec_indices)
        return None

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
        """Propose ``sample_len`` tokens per sequence by n-gram matching.

        Returns:
            A pair ``(outputs, transposed)``. ``outputs`` is None when no
            sequence in the batch produced a proposal. Otherwise it is a list
            with one entry per SequenceGroupMetadata: a SamplerOutput holding
            the proposed token ids and their one-hot probabilities, or None
            for sequences without a match. ``transposed`` is always False:
            the outputs are already laid out as needed, so the indicator
            passed on to sampler_output_to_torch must be False.

        Raises:
            NotImplementedError: on cache swap/copy requests or beam search.
        """
        self._raise_if_unsupported(execute_model_req)

        token_id_list: List[Optional[torch.Tensor]] = []
        token_prob_list: List[Optional[torch.Tensor]] = []
        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
            # _raise_if_unsupported guarantees exactly one sequence per group.
            seq_data = next(iter(seq_group_metadata.seq_data.values()))

            input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                        dtype=torch.long,
                                        device=self.device)
            proposal = self._find_ngram_proposal(input_ids,
                                                 seq_data.get_len(),
                                                 sample_len)
            token_id_list.append(proposal)
            if proposal is None:
                token_prob_list.append(None)
            else:
                token_prob_list.append(
                    torch.nn.functional.one_hot(
                        proposal,
                        num_classes=self.vocab_size).to(torch.float32))

        has_spec_out = any(ids is not None for ids in token_id_list)
        if not has_spec_out:
            return None, False

        outputs: List[Optional[SamplerOutput]] = []
        for token_ids, token_probs in zip(token_id_list, token_prob_list):
            if token_ids is None:
                outputs.append(None)
                continue
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=token_probs,
                    logprobs=torch.zeros((sample_len, self.vocab_size),
                                         dtype=torch.float32,
                                         device=self.device),
                    sampled_token_ids=token_ids,
                ))

        return outputs, False

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_proposals(execute_model_req)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """Reject requests NGramWorker cannot serve.

        NGramWorker does not yet implement support for cache swap/copy
        operations or beam search (more than one sequence per group).

        Raises:
            NotImplementedError: if any cache operation is requested or any
                sequence group does not contain exactly one sequence.
        """
        if any((
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy,
        )):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")