ngram_proposer.py 11.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4

5
import numpy as np
6
from numba import get_num_threads, jit, njit, prange, set_num_threads
7

8
9
from vllm.config import VllmConfig

10
11
12

class NgramProposer:
    """Speculative-decoding drafter that proposes tokens by n-gram
    prompt lookup over each request's own context."""

    def __init__(self, vllm_config: VllmConfig):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        # Minimum length of the n-gram to match.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        # Number of tokens that follow the match. If fewer than k tokens
        # follow the match, we return as many as are available.
        self.k = vllm_config.speculative_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Pre-allocate buffers for numba batch propose. Rows are looked up
        # by request index in batch_propose, so a batch must not contain
        # more than max_num_seqs requests.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.valid_ngram_draft = np.zeros((max_num_seqs, self.k),
                                          dtype=np.int32)
        self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)

        # Threshold of total number of tokens in the batch to enable
        # multi-threading in numba batch propose.
        self.num_tokens_threshold = 8192
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        cpu_count = os.cpu_count()
        # Max number of threads for numba parallel processing.
        if cpu_count:
            # Divide by 2 to use physical cores
            # and not logical cores (hyper-threading).
            # Cap the number of threads to 8 to avoid using too many threads
            # since other components like frontend (incl tokenization)
            # and Structured Outputs also use multiple threads.
            # TODO(ekagra-ranjan): bump up the cap from 1 to 8
            # when TP parallelization for ngram is implemented.
            self.num_numba_thread_available = min(1, (cpu_count // 2))
            # Divide by tp_size to ensure each tensor parallel rank
            # has some threads since all ranks will run this.
            self.num_numba_thread_available //= tp_size
        else:
            self.num_numba_thread_available = 1
        # The integer divisions above can round down to 0 (e.g. a single
        # CPU, or tp_size larger than the cap); never advertise 0 threads.
        self.num_numba_thread_available = max(
            1, self.num_numba_thread_available)

        # Trigger Numba JIT compilation for N-gram proposer.
        # This usually takes less than 1 second.
        # Each warmup request gets a non-empty sampled-token list: propose()
        # skips requests with no sampled tokens, so an all-empty batch would
        # never reach the numba kernel and nothing would be compiled. The
        # batch size is capped at max_num_seqs to stay within the
        # pre-allocated draft buffers.
        num_warmup_reqs = min(1024, max_num_seqs)
        self.propose([[0]] * num_warmup_reqs, [""] * num_warmup_reqs,
                     np.zeros(num_warmup_reqs, dtype=np.int32),
                     np.zeros((num_warmup_reqs, self.max_model_len),
                              dtype=np.int32), set())

    def batch_propose(
        self,
        num_requests: int,
        valid_ngram_requests: list,
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
    ) -> list[list[int]]:
        """Batch version of ngram proposer using numba for acceleration.

        Args:
            num_requests:
                Total number of requests in the batch; the returned list
                has exactly this many entries.
            valid_ngram_requests:
                List of indices of requests that need ngram proposals.
            num_tokens_no_spec:
                Numpy array of shape (batch_size,) representing the number
                of tokens without speculative tokens for each request.
            token_ids_cpu:
                Numpy array of shape (batch_size, max_model_len)
                representing the token IDs for each request.

        Returns:
            list[list[int]]:
                A list where each element is a list of proposed
                token IDs for the corresponding request. The list is empty
                for requests not in valid_ngram_requests or with no match.
        """
        draft_token_ids: list[list[int]] = []

        # Only run batch propose if there are requests needing ngram proposals.
        # avoid calling numba function with empty list which causes error
        # ValueError: cannot compute fingerprint of empty list
        if num_ngram_requests := len(valid_ngram_requests):
            original_num_numba_threads = get_num_threads()
            # Ensure we use at least one thread.
            # If total tokens is small, using multiple threads
            # may slow down due to overhead.
            total_tokens = np.sum(num_tokens_no_spec)
            if total_tokens >= self.num_tokens_threshold:
                final_num_threads = max(
                    1, min(self.num_numba_thread_available,
                           num_ngram_requests))
                set_num_threads(final_num_threads)
            else:
                set_num_threads(1)

            try:
                batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
                                    token_ids_cpu, self.min_n, self.max_n,
                                    self.max_model_len, self.k,
                                    self.valid_ngram_draft,
                                    self.valid_ngram_num_drafts)
            finally:
                # Restore original number of threads even if the kernel
                # raises, so the process-wide numba setting does not leak.
                set_num_threads(original_num_numba_threads)

        # Set membership keeps this loop O(num_requests) instead of scanning
        # the request list once per request.
        valid_ngram_request_set = set(valid_ngram_requests)
        for i in range(num_requests):
            # Membership is checked first so that buffer rows are only read
            # for requests that were actually processed above.
            if (i in valid_ngram_request_set
                    and self.valid_ngram_num_drafts[i] > 0):
                draft_token_ids.append(self.valid_ngram_draft[
                    i, :self.valid_ngram_num_drafts[i]].tolist())
            else:
                draft_token_ids.append([])

        return draft_token_ids

    def propose(
        self,
        sampled_token_ids: list[list[int]],
        req_ids: list[str],
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
        spec_decode_unsupported_reqs: set,
    ) -> list[list[int]]:
        """Propose n-gram draft tokens for every request in the batch.

        A request is considered only if it sampled at least one token,
        its id is not in ``spec_decode_unsupported_reqs``, and its
        ``num_tokens_no_spec`` entry is below ``max_model_len``.

        Returns one (possibly empty) list of draft token ids per entry of
        ``sampled_token_ids``.
        """
        # find which requests need ngram proposals
        valid_ngram_requests = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = req_ids[i]
            if req_id in spec_decode_unsupported_reqs:
                continue

            num_tokens = num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                continue

            valid_ngram_requests.append(i)

        draft_token_ids = self.batch_propose(
            len(sampled_token_ids),
            valid_ngram_requests,
            num_tokens_no_spec,
            token_ids_cpu,
        )

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass

167

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
@njit(parallel=True)
def batch_propose_numba(valid_ngram_requests: list,
                        num_tokens_no_spec: np.ndarray,
                        token_ids_cpu: np.ndarray, min_n: int, max_n: int,
                        max_model_len: int, k: int,
                        valid_ngram_draft: np.ndarray,
                        valid_ngram_num_drafts: np.ndarray):
    """Compute n-gram drafts for the selected requests, in parallel.

    Results are written in place and keyed by request index ``idx`` (the
    value stored in ``valid_ngram_requests``), not by position in that list:
    ``valid_ngram_num_drafts[idx]`` receives the number of drafted tokens
    and ``valid_ngram_draft[idx, :n]`` receives the tokens themselves. Rows
    of requests not listed in ``valid_ngram_requests`` are left untouched.

    Args:
        valid_ngram_requests: Request indices to process.
        num_tokens_no_spec: Per-request context length.
        token_ids_cpu: Per-request token ids (row ``idx`` is request idx).
        min_n / max_n: Inclusive n-gram length range to match.
        max_model_len: Drafts never extend the context past this length.
        k: Maximum number of draft tokens per request.
        valid_ngram_draft: Output buffer; presumably (max_num_seqs, k).
        valid_ngram_num_drafts: Output buffer of draft counts per request.
    """
    for i in prange(len(valid_ngram_requests)):
        idx = valid_ngram_requests[i]
        num_tokens = num_tokens_no_spec[idx]
        context_token_ids = token_ids_cpu[idx, :num_tokens]
        drafter_output = _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=context_token_ids,
            min_ngram=min_n,
            max_ngram=max_n,
            max_model_len=max_model_len,
            k=k)

        # Write by request index ``idx``: the caller reads these buffers by
        # request index, so indexing by loop position ``i`` would attribute
        # drafts to the wrong request whenever some requests were filtered
        # out of ``valid_ngram_requests``.
        num_drafts = drafter_output.shape[0]
        valid_ngram_num_drafts[idx] = num_drafts
        if num_drafts:
            valid_ngram_draft[idx, :num_drafts] = drafter_output


191
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray,
                                                   min_ngram: int,
                                                   max_ngram: int,
                                                   max_model_len: int,
                                                   k: int) -> np.ndarray:
    """
    Find the longest n-gram which matches the suffix of the given tokens
    whose length is within [min_ngram, max_ngram] (inclusive).

    If found, extract up to k tokens right after the matched n-gram. Among
    equally long matches, the earliest occurrence in origin_tokens wins.

    Args:
        origin_tokens: 1-D array of context token ids.
        min_ngram: Minimum n-gram length that counts as a match.
        max_ngram: Maximum n-gram length considered.
        max_model_len: Drafts are truncated so that context plus drafts
            does not exceed this length.
        k: Maximum number of draft tokens to return.

    Returns:
        A slice of origin_tokens with at most k draft tokens, or an empty
        array of the same dtype if the context is shorter than min_ngram,
        there is no room left before max_model_len, or no n-gram of at
        least min_ngram tokens matches.
    """
    # Do not generate draft tokens if context is shorter than minimum n-gram
    total_token = origin_tokens.shape[0]
    if total_token < min_ngram:
        return np.empty((0, ), dtype=origin_tokens.dtype)

    # Do not generate draft tokens beyond the max model length.
    k = min(k, max_model_len - total_token)
    if k <= 0:
        return np.empty((0, ), dtype=origin_tokens.dtype)

    # Flip tokens, and the goal becomes to find the longest ngram
    # at the rightmost position which matches the prefix with
    # length [min_n, max_n] (inclusive).
    tokens = origin_tokens[::-1]

    # Longest prefix (not including itself) which is a suffix of
    # the current position.
    #   lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]}
    #
    # As ngram is capped by max_ngram to save memory, we only need to
    # store lps for the first max_ngram prefix.
    lps = np.zeros(max_ngram, dtype=np.int32)

    longest_ngram = 0
    position = 0

    # lps[0] always equal to 0, we start with index 1
    prev_lps = 0
    i = 1
    while i < total_token:
        # tokens[:prev_lps] is the longest prefix as a suffix of tokens[:i]
        if tokens[prev_lps] == tokens[i]:
            # Token match: tokens[:prev_lps+1] is the longest prefix as
            # a suffix of tokens[:i+1]
            prev_lps += 1
            # Check if we found a longer valid ngram.
            #
            # Update position when longest_ngram matched prev_lps,
            # as we want to get the target n-gram of the earliest position
            # in the original tokens (i.e.
            # latest position in the reversed tokens)
            if prev_lps >= longest_ngram:
                longest_ngram = prev_lps
                position = i
            if i < max_ngram:
                # Store LPS for the first max_ngram prefix
                lps[i] = prev_lps
            if prev_lps == max_ngram:
                # When prev_lps reached max_ngram, update prev_lps
                # to lps[max_ngram-1] to avoid matching ngram
                # longer than max_ngram
                prev_lps = lps[max_ngram - 1]
            i += 1
        elif prev_lps != 0:
            # Token mismatch: try the second longest prefix
            # among all suffix of tokens[:i],
            # which is the longest prefix of tokens[:prev_lps]
            prev_lps = lps[prev_lps - 1]
        else:
            # Token mismatch, and no more prefix (except empty string)
            # as a suffix of tokens[:i]
            i += 1

    if longest_ngram < min_ngram:
        # No valid ngram is found
        return np.empty((0, ), dtype=origin_tokens.dtype)

    # Flip the position back, so in origin_tokens,
    # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
    # is the matched ngram, so we should start drafting tokens from
    # total_token-1-position+longest_ngram
    start_position = total_token - 1 - position + longest_ngram
    k = min(k, total_token - start_position)
    return origin_tokens[start_position:start_position + k]