# ngram_proposer.py
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional

from vllm.v1.utils import ConstantList


class NgramProposer:
    """Proposes continuation tokens by n-gram matching within the context.

    The last ``n`` tokens of the context are used as a pattern. The
    earlier part of the context is searched for that pattern with the
    Knuth-Morris-Pratt (KMP) algorithm, and the tokens that followed the
    first occurrence are returned as the proposal.
    """

    def propose(self, context_token_ids: ConstantList[int], n: int,
                k: int) -> Optional[List[int]]:
        """Proposes the next sequence of tokens based on n-gram pattern
        matching in the context. The function finds the first match of the
        last n tokens in the preceding context, and returns up to k tokens
        that followed that match.

        Args:
            context_token_ids: List of token IDs representing the
                               context sequence.
            n: Length of the n-gram to match. Must be positive.
            k: Number of tokens to return after the match. Must not be
               negative. If fewer than k tokens follow the match, all
               tokens up to the end of the context are returned.

        Returns:
            List[int]: The sequence of tokens that followed
                       the matched n-gram in the context.
            None: If no matching n-gram pattern is found.

        Raises:
            ValueError: If n is not positive or k is negative.

        Example:
            If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4:
            - The last 2 tokens [2,3] will be matched against the previous
              4 tokens [1,2,3,4].
            - Finding a match of [2,3] would return the tokens that
              followed that pattern. Here we will return [4,2,3] because
              we only have three tokens after the match.
        """
        # TODO: Use c++ to implement the _find_subarray_kmp to
        # improve the efficiency
        return self._find_subarray_kmp(context_token_ids, n, k)

    @staticmethod
    def _kmp_lps_array(pattern: List[int]) -> List[int]:
        """
        Build the lps (longest proper prefix which is also suffix)
        array for the pattern.

        lps[i] is the length of the longest proper prefix of
        pattern[:i + 1] that is also a suffix of it. An empty pattern
        yields an empty list.
        """
        lps = [0] * len(pattern)
        # Length of the prefix-suffix currently being extended.
        prefix_len = 0

        for i in range(1, len(pattern)):
            # On a mismatch fall back to the next shorter candidate
            # prefix until one can be extended or none is left.
            while prefix_len and pattern[i] != pattern[prefix_len]:
                prefix_len = lps[prefix_len - 1]
            if pattern[i] == pattern[prefix_len]:
                prefix_len += 1
            lps[i] = prefix_len

        return lps

    @staticmethod
    def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
                           k: int) -> Optional[List[int]]:
        """
        Search the context (excluding its last n tokens) for the first
        occurrence of those last n tokens and return what followed it.

        Args:
            context_token_ids: Token IDs of the context sequence.
            n: Length of the trailing n-gram used as the pattern.
            k: Maximum number of tokens to return after the match.

        Returns:
            The slice of at most k tokens following the first match, or
            None if the pattern does not occur.

        Raises:
            ValueError: If n is not positive or k is negative.
        """
        # Explicit checks instead of `assert`, which is stripped under
        # `python -O` and would let n <= 0 build a nonsensical pattern.
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")
        # A negative k would turn the result slice's stop index into an
        # offset from the end and return unrelated tokens.
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")

        # The last n tokens are the pattern, so only the tokens before
        # them are searched; a match must end before the pattern starts.
        search_end = len(context_token_ids) - n
        if search_end <= 0:
            # Nothing precedes the pattern: no match is possible.
            return None

        pattern = context_token_ids[-n:]
        # Precompute the lps array for the pattern.
        lps = NgramProposer._kmp_lps_array(pattern)

        matched = 0  # number of pattern tokens matched so far
        for pos in range(search_end):
            token = context_token_ids[pos]
            # Mismatch: use the lps array to avoid re-checking elements.
            while matched and token != pattern[matched]:
                matched = lps[matched - 1]
            if token == pattern[matched]:
                matched += 1
                if matched == n:
                    # Found the pattern; gather the next k elements.
                    start = pos + 1
                    return context_token_ids[start:start + k]

        # Pattern not found.
        return None