ngram_proposer.py 3.34 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional

4
import numpy as np
5
6
7
8
9
10
11


class NgramProposer:
    """Propose draft tokens by n-gram lookup in the existing context.

    The proposer holds no state: each call to ``propose`` searches the
    given context for an earlier occurrence of its trailing n-gram and
    returns the tokens that followed that occurrence.
    """

    def __init__(self):
        pass

    def propose(
        self,
        context_token_ids: np.ndarray,
        n: int,
        k: int,
    ) -> Optional[np.ndarray]:
        """Propose the next tokens by matching the trailing n-gram.

        The last ``n`` tokens of the context are used as the pattern and
        are searched for in the preceding ``len(context_token_ids) - n``
        tokens. For the earliest match, up to ``k`` tokens that follow it
        are returned.

        Args:
            context_token_ids: 1-D array of token IDs representing the
                context sequence (a plain list of ints is also accepted).
            n: Length of the n-gram to match. Must be positive.
            k: Maximum number of tokens to return after the match. If
                fewer than ``k`` tokens follow the match, all tokens up to
                the end of the context are returned.

        Returns:
            np.ndarray: The tokens that followed the earliest match of the
                trailing n-gram.
            None: If the context is not longer than ``n`` or no matching
                n-gram is found.

        Example:
            If context_token_ids = [1, 2, 3, 4, 2, 3], n = 2, and k = 4:
            - The last 2 tokens [2, 3] are matched against the preceding
              4 tokens [1, 2, 3, 4].
            - [2, 3] occurs at positions 1-2, so the tokens that follow it
              are returned: [4, 2, 3]. Only three tokens follow the match,
              so fewer than k = 4 tokens are returned.
        """
        # TODO: Implement _find_subarray_kmp in C++ to improve efficiency.
        return self._find_subarray_kmp(context_token_ids, n, k)

    @staticmethod
    def _kmp_lps_array(pattern: np.ndarray) -> List[int]:
        """Build the KMP failure (lps) table for ``pattern``.

        ``lps[i]`` is the length of the longest proper prefix of
        ``pattern[:i + 1]`` that is also a suffix of it.

        Args:
            pattern: Indexable sequence of tokens (a 1-D NumPy array or a
                list of ints).

        Returns:
            A list with the same length as ``pattern``.
        """
        lps = [0] * len(pattern)
        prefix_len = 0  # length of the current longest prefix-suffix
        i = 1

        while i < len(pattern):
            if pattern[i] == pattern[prefix_len]:
                prefix_len += 1
                lps[i] = prefix_len
                i += 1
            elif prefix_len != 0:
                # Fall back to the next shorter prefix-suffix without
                # advancing i, so pattern[i] is compared again.
                prefix_len = lps[prefix_len - 1]
            else:
                # No prefix-suffix ends at i; lps[i] keeps its initial 0.
                i += 1

        return lps

    @staticmethod
    def _find_subarray_kmp(
        context_token_ids: np.ndarray,
        n: int,
        k: int,
    ) -> Optional[np.ndarray]:
        """Find the trailing n-gram earlier in the context using KMP.

        Only the first ``len(context_token_ids) - n`` tokens are searched.
        As a result, the trailing n-gram never matches itself, and
        occurrences that overlap it are not considered.

        Returns:
            A slice of ``context_token_ids`` with up to ``k`` tokens that
            follow the earliest match, or None if there is no match.
        """
        assert n > 0
        # Accept plain lists too; np.asarray returns ndarray input unchanged.
        context_token_ids = np.asarray(context_token_ids)
        context_len = context_token_ids.shape[0]
        # No tokens precede the pattern, so there is nothing to search.
        if context_len <= n:
            return None

        pattern = context_token_ids[-n:]
        lps = NgramProposer._kmp_lps_array(pattern)

        # The last n tokens are the pattern itself, so exclude them.
        search_end = context_len - n
        i = 0  # current position in context_token_ids
        j = 0  # number of pattern tokens matched so far
        while i < search_end:
            if context_token_ids[i] == pattern[j]:
                i += 1
                j += 1
                if j == n:
                    # Full match ends at i - 1; return what follows it.
                    return context_token_ids[i:i + k]
            elif j != 0:
                # Mismatch after a partial match: keep the longest pattern
                # prefix that is still matched instead of rescanning.
                j = lps[j - 1]
            else:
                i += 1

        # Pattern not found.
        return None