# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import weakref
5
from typing import List, Optional, Set, Tuple
6
7

import torch
8
import torch.nn as nn
9

10
from vllm.config import VllmConfig
11
12
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
13
from vllm.spec_decode.interfaces import SpeculativeProposals
14
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
15
16
17
from vllm.spec_decode.top1_proposer import Top1Proposer


18
19
20
21
class _DummyModel(nn.Module):
    """Empty placeholder module with no parameters or submodules.

    NGramWorker.get_model() returns an instance of this class because the
    n-gram drafter has no real model to expose.
    """


22
class NGramWorker(NonLLMProposerWorkerBase):
    """NGramWorker provides a light drafter without the need for a model.

    The current NGramWorker only implements prompt lookup decoding: the most
    recent n tokens of a sequence are searched for in the earlier part of the
    same sequence, and the tokens that followed the match are proposed as the
    speculative continuation. In the future it may also cover RAG-type
    drafters and other scenarios which don't rely on an LLM to give
    proposals.
    """

    # Sequences shorter than this are matched on CPU; longer ones are matched
    # on ``self.device``. 3072 (3K) is a rough threshold based on profiling on
    # H100 and can be adjusted for different hardware.
    _CPU_MATCH_SEQ_LEN_THRESHOLD = 3072

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        device_type: str = "cuda",
        **kwargs,
    ):
        """Store the device placement info; heavy setup is deferred.

        Args:
            vllm_config: Forwarded unchanged to the base-class constructor.
            local_rank: Device index used to build ``self.device`` in
                ``init_device``.
            device_type: Device family prefix (e.g. ``"cuda"``).
            **kwargs: Accepted for signature compatibility and ignored.
        """
        super().__init__(vllm_config)

        self.local_rank = local_rank
        self.device_type = device_type
        # NOTE(review): ``self.vocab_size`` is read by this class but never
        # assigned here; presumably it is provided by the base class --
        # confirm against NonLLMProposerWorkerBase.

        # Lazily initialized in init_device(), once the device is known.
        self._proposer: Top1Proposer

    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        """Set the inclusive range of n-gram sizes tried during lookup.

        ``sampler_output`` searches from ``ngram_prompt_lookup_max`` down to
        ``ngram_prompt_lookup_min`` and therefore requires this method to
        have been called first.
        """
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

    def init_device(self):
        """Resolve the torch device and build the proposer."""
        self.device = torch.device(f"{self.device_type}:{self.local_rank}")

        # Current NGramWorker only supports Top1Proposer.
        # A weak proxy is handed over so the proposer does not hold a strong
        # reference back to this worker.
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def load_model(self) -> None:
        """No-op: the n-gram drafter has no model weights to load."""
        pass  # Dummy

    def get_model(self) -> nn.Module:
        """Return a fresh empty placeholder module."""
        return _DummyModel()

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter. NGramWorker does not use the KV Cache and
        # therefore does not need this parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
        """NGram match algorithm to pick proposal candidates.

        Args:
            execute_model_req: Batch to propose for; each sequence group must
                hold exactly one sequence and no cache operations.
            sample_len: Number of tokens to propose per sequence.
            seq_ids_with_bonus_token_in_last_step: Unused.

        Returns:
            ``(outputs, False)`` where ``outputs`` has one entry per
            SequenceGroupMetadata: a SamplerOutput for sequences with an
            n-gram match and ``None`` for those without. If no sequence in
            the batch has a match, ``outputs`` itself is ``None``.

            The second element is the transpose indicator passed on to
            sampler_output_to_torch; the needed transposition is already done
            internally here, so it is always ``False``.

        Raises:
            NotImplementedError: If the request contains cache operations or
                sequence groups with more than one sequence.
        """
        self._raise_if_unsupported(execute_model_req)

        token_id_list: List[Optional[torch.Tensor]] = []
        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
            # Exactly one sequence per group is guaranteed by the check above.
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            token_id_list.append(
                self._lookup_proposal(seq_data, sample_len))

        # Nothing matched anywhere in the batch (or the batch is empty).
        if all(token_ids is None for token_ids in token_id_list):
            return None, False

        outputs: List[Optional[SamplerOutput]] = []
        for token_ids in token_id_list:
            if token_ids is None:
                outputs.append(None)
                continue
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    # Each proposed token gets probability 1.0.
                    sampled_token_probs=torch.nn.functional.one_hot(
                        token_ids,
                        num_classes=self.vocab_size).to(torch.float32),
                    logprobs=torch.zeros((sample_len, self.vocab_size),
                                         dtype=torch.float32,
                                         device=self.device),
                    sampled_token_ids=token_ids,
                ))

        return outputs, False

    def _lookup_proposal(self, seq_data,
                         sample_len: int) -> Optional[torch.Tensor]:
        """Find proposal tokens for one sequence via prompt lookup.

        N-gram sizes are tried from the largest allowed down to
        ``ngram_prompt_lookup_min``; the first size whose trailing n-gram
        also occurs earlier in the sequence wins.

        Args:
            seq_data: Object exposing ``get_len()`` and ``get_token_ids()``
                for the sequence.
            sample_len: Number of tokens to propose.

        Returns:
            A 1-D long tensor of ``sample_len`` token ids on ``self.device``,
            or ``None`` if no n-gram size produced a match.
        """
        seq_len = seq_data.get_len()
        # Short sequences are matched on CPU, long ones on the worker device.
        cur_device = ("cpu" if seq_len < self._CPU_MATCH_SEQ_LEN_THRESHOLD
                      else self.device)
        input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                    dtype=torch.long,
                                    device=cur_device)

        for ngram_size in range(
                min(self.ngram_prompt_lookup_max, seq_len - 1),
                self.ngram_prompt_lookup_min - 1,
                -1,
        ):
            ngram_tensor = input_ids[-ngram_size:]
            if ngram_size == 1:
                # Do not match itself and do not use unfold and all
                matches = (input_ids[:-1] == ngram_tensor)
            else:
                windows = input_ids.unfold(dimension=0,
                                           size=ngram_size,
                                           step=1)
                # Do not match itself (the last window is the suffix).
                matches = (windows[:-1] == ngram_tensor).all(dim=-1)

            # first_match includes "values" (bool), indicating whether
            # the match is found, and "indices", indicating the index
            # of the first match.
            first_match = matches.max(dim=-1)
            if not first_match.values.item():
                continue

            # Proposed tokens start right after the matched n-gram.
            proposal_start_idx = first_match.indices.add_(ngram_size)
            spec_indices = proposal_start_idx.repeat(
                sample_len) + torch.arange(sample_len, device=cur_device)
            # Indices running past the end are clamped to the last token.
            spec_indices.clamp_(max=input_ids.shape[-1] - 1)
            return input_ids.gather(dim=-1,
                                    index=spec_indices).to(self.device)

        return None

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        # Unused parameter. NGramWorker does not use the KV Cache and
        # therefore does not need this parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.

        Delegates to the proposer created in ``init_device``.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """NGramWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: If any of blocks_to_swap_in,
                blocks_to_swap_out or blocks_to_copy is non-empty, or if any
                sequence group does not hold exactly one sequence.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")