ngram_proposer.py 10.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4

5
import numpy as np
6
import torch
7
from numba import get_num_threads, jit, njit, prange, set_num_threads
8

9
10
from vllm.config import VllmConfig

11
12

class NgramProposer:
    """Propose speculative draft tokens by n-gram lookup in the context.

    For every eligible request the numba kernel searches the request's own
    token history for an n-gram (length within ``[min_n, max_n]``) matching
    the current suffix, and proposes up to ``k`` tokens that followed it.
    """

    def __init__(self, vllm_config: VllmConfig):
        spec_config = vllm_config.speculative_config
        assert spec_config is not None
        assert spec_config.prompt_lookup_min is not None
        assert spec_config.prompt_lookup_max is not None

        # Minimum length of the n-gram to match.
        self.min_n = spec_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = spec_config.prompt_lookup_max
        # Number of tokens to propose after the match. If fewer than k
        # tokens follow the match, all tokens up to the end are returned.
        self.k = spec_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Pre-allocate output buffers for numba batch propose. Row i holds
        # the drafts of request i; valid_ngram_num_drafts[i] is how many
        # entries of that row are meaningful.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32)
        self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)

        # Threshold of total number of tokens in the batch to enable
        # multi-threading in numba batch propose.
        self.num_tokens_threshold = 8192
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        cpu_count = os.cpu_count()
        # Max number of threads for numba parallel processing.
        if cpu_count:
            # Divide by 2 to use physical cores
            # and not logical cores (hyper-threading).
            # The thread count is capped (currently at 1) to avoid using too
            # many threads, since other components like frontend
            # (incl tokenization) and Structured Outputs also use
            # multiple threads.
            # TODO(ekagra-ranjan): bump up the cap from 1 to 8
            # when TP parallelization for ngram is implemented.
            self.num_numba_thread_available = min(1, (cpu_count // 2))
            # Divide by tp_size to ensure each tensor parallel rank
            # has some threads since all ranks will run this.
            # NOTE: this may evaluate to 0; batch_propose clamps the value
            # to at least one thread before using it.
            self.num_numba_thread_available //= tp_size
        else:
            self.num_numba_thread_available = 1

        # Trigger Numba JIT compilation for N-gram proposer.
        # The warm-up request must carry a sampled token: `propose` skips
        # requests whose sampled list is empty, and `batch_propose` does not
        # call the numba kernel when no request is left, so an all-empty
        # warm-up batch would compile nothing.
        # A single request is used so that only row 0 of the pre-allocated
        # buffers (which have max_num_seqs rows) is touched.
        if max_num_seqs > 0:
            self.propose(
                [[0]],
                np.zeros(1, dtype=np.int32),
                np.zeros((1, self.max_model_len), dtype=np.int32),
            )

    def batch_propose(
        self,
        num_requests: int,
        valid_ngram_requests: list,
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
    ) -> list[list[int]]:
        """Batch version of ngram proposer using numba for acceleration.

        Args:
            num_requests:
                Total number of requests in the batch; the returned list
                has exactly this many entries.
            valid_ngram_requests:
                List of indices of requests that need ngram proposals.
            num_tokens_no_spec:
                Numpy array of shape (batch_size,) representing the number
                of tokens without speculative tokens for each request.
            token_ids_cpu:
                Numpy array of shape (batch_size, max_model_len)
                representing the token IDs for each request.

        Returns:
            list[list[int]]:
                A list where each element is a list of proposed
                token IDs for the corresponding request. Requests that are
                not in ``valid_ngram_requests`` or that got no match
                receive an empty list.
        """
        # Only run batch propose if there are requests needing ngram proposals.
        # avoid calling numba function with empty list which causes error
        # ValueError: cannot compute fingerprint of empty list
        if num_ngram_requests := len(valid_ngram_requests):
            original_num_numba_threads = get_num_threads()
            # Ensure we use at least one thread.
            # If total tokens is small, using multiple threads
            # may slow down due to overhead.
            total_tokens = np.sum(num_tokens_no_spec)
            if total_tokens >= self.num_tokens_threshold:
                final_num_threads = max(
                    1, min(self.num_numba_thread_available, num_ngram_requests)
                )
            else:
                final_num_threads = 1

            try:
                set_num_threads(final_num_threads)
                batch_propose_numba(
                    valid_ngram_requests,
                    num_tokens_no_spec,
                    token_ids_cpu,
                    self.min_n,
                    self.max_n,
                    self.max_model_len,
                    self.k,
                    self.valid_ngram_draft,
                    self.valid_ngram_num_drafts,
                )
            finally:
                # Restore original number of threads, even if the kernel
                # raised, so the setting does not leak to other numba users.
                set_num_threads(original_num_numba_threads)

        # Build the membership set once; testing against the list inside the
        # loop would make result assembly quadratic in the batch size.
        valid_request_set = set(valid_ngram_requests)
        draft_token_ids: list[list[int]] = []
        for i in range(num_requests):
            if i in valid_request_set and self.valid_ngram_num_drafts[i] > 0:
                num_drafts = self.valid_ngram_num_drafts[i]
                draft_token_ids.append(self.valid_ngram_draft[i, :num_drafts].tolist())
            else:
                draft_token_ids.append([])

        return draft_token_ids

    def propose(
        self,
        sampled_token_ids: list[list[int]],
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,  # unused
    ) -> list[list[int]]:
        """Propose draft tokens for every request in the batch.

        Args:
            sampled_token_ids:
                Per-request list of sampled token IDs. A request with an
                empty list is skipped (no speculative decoding for it).
            num_tokens_no_spec:
                Numpy array with the number of tokens (without speculative
                tokens) of each request.
            token_ids_cpu:
                Numpy array holding the token IDs of each request, one row
                per request.
            slot_mappings:
                Unused; accepted only for interface compatibility.

        Returns:
            list[list[int]]:
                One list of draft token IDs per request, empty for
                requests that were skipped or had no n-gram match.
        """
        # Find which requests need ngram proposals.
        valid_ngram_requests = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            if len(sampled_ids) == 0:
                # Skip speculative decoding.
                continue
            if num_tokens_no_spec[i] >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                continue
            valid_ngram_requests.append(i)

        return self.batch_propose(
            len(sampled_token_ids),
            valid_ngram_requests,
            num_tokens_no_spec,
            token_ids_cpu,
        )

    def load_model(self, *args, **kwargs):
        """No-op: the n-gram proposer has no model to load."""
        pass

168

169
@njit(parallel=True)
def batch_propose_numba(
    valid_ngram_requests: list,
    num_tokens_no_spec: np.ndarray,
    token_ids_cpu: np.ndarray,
    min_n: int,
    max_n: int,
    max_model_len: int,
    k: int,
    valid_ngram_draft: np.ndarray,
    valid_ngram_num_drafts: np.ndarray,
):
    """Run the n-gram drafter for each listed request, writing results in place.

    For every request index in ``valid_ngram_requests`` the first
    ``num_tokens_no_spec[idx]`` tokens of ``token_ids_cpu[idx]`` are used as
    context. The number of proposed tokens is stored in
    ``valid_ngram_num_drafts[idx]`` and the tokens themselves in the leading
    entries of ``valid_ngram_draft[idx]``. Nothing is returned.
    """
    num_requests = len(valid_ngram_requests)
    for slot in prange(num_requests):
        req_idx = valid_ngram_requests[slot]
        context_len = num_tokens_no_spec[req_idx]
        proposed = _find_longest_matched_ngram_and_propose_tokens(
            token_ids_cpu[req_idx, :context_len],
            min_n,
            max_n,
            max_model_len,
            k,
        )

        # Always record the count so a request without a match reads as 0.
        num_proposed = proposed.shape[0]
        valid_ngram_num_drafts[req_idx] = num_proposed
        if num_proposed > 0:
            valid_ngram_draft[req_idx, :num_proposed] = proposed
196
197


198
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens: np.ndarray,
    min_ngram: int,
    max_ngram: int,
    max_model_len: int,
    k: int,
) -> np.ndarray:
    """
    Find the longest n-gram which matches the suffix of the given tokens
    and whose length is within [min_ngram, max_ngram] (inclusive).

    If one is found, return up to k tokens that come right after the matched
    n-gram in origin_tokens (fewer when the context ends earlier or when
    drafting would exceed max_model_len). Otherwise return an empty array.
    """
    num_tokens = origin_tokens.shape[0]

    # Do not generate draft tokens if the context is shorter than the
    # minimum n-gram.
    if num_tokens < min_ngram:
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Do not generate draft tokens beyond the max model length.
    draft_budget = min(k, max_model_len - num_tokens)
    if draft_budget <= 0:
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Flip the tokens, so the goal becomes finding the longest n-gram,
    # at the rightmost position, which matches the prefix with
    # length [min_ngram, max_ngram] (inclusive).
    reversed_tokens = origin_tokens[::-1]

    # Longest prefix (not including itself) which is a suffix of
    # the current position.
    #   prefix_table[i] = max{v, where
    #       reversed_tokens[0:v] == reversed_tokens[i+1-v:i+1]}
    #
    # As the n-gram is capped by max_ngram to save memory, only the
    # entries for the first max_ngram prefixes are stored.
    prefix_table = np.zeros(max_ngram, dtype=np.int32)

    best_len = 0
    best_pos = 0

    # prefix_table[0] is always 0, so scanning starts at index 1.
    matched = 0
    cursor = 1
    while cursor < num_tokens:
        # reversed_tokens[:matched] is the longest prefix that is also a
        # suffix of reversed_tokens[:cursor].
        if reversed_tokens[matched] != reversed_tokens[cursor]:
            if matched == 0:
                # Token mismatch, and no more prefix (except empty string)
                # as a suffix of reversed_tokens[:cursor].
                cursor += 1
            else:
                # Token mismatch: try the second-longest prefix among all
                # suffixes of reversed_tokens[:cursor], which is the longest
                # prefix of reversed_tokens[:matched].
                matched = prefix_table[matched - 1]
            continue

        # Token match: reversed_tokens[:matched+1] is the longest prefix
        # that is a suffix of reversed_tokens[:cursor+1].
        matched += 1

        # Check if we found a longer valid n-gram.
        #
        # The position is also updated on ties, because we want the target
        # n-gram at the earliest position in the original tokens (i.e. the
        # latest position in the reversed tokens).
        if matched >= best_len:
            best_len = matched
            best_pos = cursor
        if cursor < max_ngram:
            # Store the value for the first max_ngram prefixes only.
            prefix_table[cursor] = matched
        if matched == max_ngram:
            # Once the match length reaches max_ngram, fall back to
            # prefix_table[max_ngram - 1] to avoid matching an n-gram
            # longer than max_ngram.
            matched = prefix_table[max_ngram - 1]
        cursor += 1

    if best_len < min_ngram:
        # No valid n-gram was found.
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Flip the position back: in origin_tokens,
    # origin_tokens[num_tokens-1-best_pos:num_tokens-1-best_pos+best_len]
    # is the matched n-gram, so drafting starts at
    # num_tokens-1-best_pos+best_len.
    draft_start = num_tokens - 1 - best_pos + best_len
    num_drafts = min(draft_budget, num_tokens - draft_start)
    return origin_tokens[draft_start : draft_start + num_drafts]