ngram_proposer.py 4.19 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Optional
4

5
import numpy as np
6
from numba import jit
7

8
9
from vllm.config import VllmConfig

10
11
12

class NgramProposer:
    """Draft-token proposer based on n-gram (prompt lookup) matching.

    It searches the preceding context for an earlier occurrence of the
    trailing n tokens and proposes the tokens that followed it. No model is
    involved, so ``load_model`` is a no-op.
    """

    def __init__(self, vllm_config: VllmConfig):
        # Minimum length of the n-gram to match.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        # Number of tokens to propose after a match. If fewer than k tokens
        # follow the match, all tokens up to the end of the context are
        # returned instead.
        self.k = vllm_config.speculative_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Trigger Numba JIT compilation for the N-gram matcher.
        # This usually takes less than 1 second.
        # The jitted helper is called directly instead of via
        # ``self.propose``: ``propose`` returns early, before any jitted code
        # runs, when the context leaves no room below ``max_model_len``. With
        # a 1024-token dummy context that happens whenever
        # ``max_model_len <= 1024``, which would silently defer compilation
        # to the first real request. The argument types (int32 array, int,
        # int) match those used by ``propose``.
        _find_subarray_kmp(np.zeros(1024, dtype=np.int32), 1, 1)

    def propose(
        self,
        context_token_ids: np.ndarray,
    ) -> Optional[np.ndarray]:
        """Propose the next tokens via n-gram pattern matching in the context.

        For each n from ``max_n`` down to ``min_n`` (longest first), look for
        an earlier occurrence of the last n tokens in the preceding context
        and return the tokens that followed the first such occurrence.

        Args:
            context_token_ids: Numpy array of token IDs representing the
                context sequence.

        Returns:
            np.ndarray: The tokens that followed the matched n-gram in the
                context. At most k tokens are returned, fewer if the context
                ends sooner. The count never exceeds
                ``max_model_len - len(context_token_ids)``.
            None: If no matching n-gram pattern is found, or if the context
                already reaches ``max_model_len`` so no draft token fits.

        Example:
            If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
            k = 4:
            - The last 3 (= max_n) tokens [4,2,3] cannot find a match.
            - The last 2 tokens [2,3] will be matched against the previous
              4 tokens [1,2,3,4].
            - Finding a match of [2,3] returns the tokens that followed
              that pattern. Here we return [4,2,3] because only three
              tokens follow the match.
        """
        # Do not generate draft tokens beyond the max model length.
        k = min(self.k, self.max_model_len - context_token_ids.shape[0])
        if k <= 0:
            return None

        # Prefer the longest n-gram: try n = max_n first, down to min_n.
        # TODO(woosuk): Optimize this.
        for n in range(self.max_n, self.min_n - 1, -1):
            result = _find_subarray_kmp(context_token_ids, n, k)
            if result is not None:
                return result
        return None

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass

73

74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
@jit(nopython=True)
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
    """
    Compute the KMP failure table for ``pattern``.

    Entry ``pos`` holds the length of the longest proper prefix of
    ``pattern[:pos + 1]`` that is also a suffix of it. An empty pattern
    yields an empty table.
    """
    pattern_len = len(pattern)
    table = np.zeros(pattern_len, dtype=np.int32)
    border = 0  # length of the current longest prefix-suffix

    for pos in range(1, pattern_len):
        # On mismatch, fall back through shorter borders until one can be
        # extended by pattern[pos] or none is left.
        while border > 0 and pattern[pos] != pattern[border]:
            border = table[border - 1]
        if pattern[pos] == pattern[border]:
            border += 1
        table[pos] = border
    return table
96
97


98
99
100
101
102
103
104
105
@jit(nopython=True)
def _find_subarray_kmp(
    context_token_ids: np.ndarray,
    n: int,
    k: int,
) -> Optional[np.ndarray]:
    """
    Find the trailing ``n`` tokens earlier in the context using KMP.

    Only the first ``len(context_token_ids) - n`` tokens are scanned, so a
    match must lie entirely before the trailing pattern itself. On the
    first (leftmost) full match, return the slice of up to ``k`` tokens
    that immediately follow it. Return ``None`` if there is no match.
    """
    text_len = context_token_ids.shape[0]
    assert n > 0

    # The pattern is the last n tokens of the context.
    pattern = context_token_ids[-n:]
    failure = _kmp_lps_array(pattern)

    matched = 0  # number of pattern tokens currently matched
    # Stop n tokens short of the end: those tokens are the pattern itself.
    for pos in range(text_len - n):
        token = context_token_ids[pos]
        # Reuse the failure table instead of re-scanning the text.
        while matched > 0 and token != pattern[matched]:
            matched = failure[matched - 1]
        if token == pattern[matched]:
            matched += 1
            if matched == n:
                # Whole pattern matched ending at pos: return what follows.
                return context_token_ids[pos + 1:pos + 1 + k]

    # Pattern not found.
    return None