# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional import numpy as np from numba import jit from vllm.config import VllmConfig class NgramProposer: def __init__(self, vllm_config: VllmConfig): assert vllm_config.speculative_config is not None assert vllm_config.speculative_config.prompt_lookup_min is not None assert vllm_config.speculative_config.prompt_lookup_max is not None # Minimum length of the n-gram to match. self.min_n = vllm_config.speculative_config.prompt_lookup_min # Maximum length of the n-gram to match. self.max_n = vllm_config.speculative_config.prompt_lookup_max # Number of tokens follow the match. If there are less than k # tokens follow the match, we will return the maximum amount of # tokens until the end. self.k = vllm_config.speculative_config.num_speculative_tokens # Maximum length of the model. self.max_model_len = vllm_config.model_config.max_model_len # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose(np.zeros(1024, dtype=np.int32)) def propose( self, context_token_ids: np.ndarray, ) -> Optional[np.ndarray]: """Proposes the next sequence of tokens based on n-gram pattern matching in the context. The function finds matches of the last n tokens in the previous context, and returns k tokens that followed that match. Args: context_token_ids: Numpy array of token IDs representing the context sequence. Returns: np.ndarray: The sequence of tokens that followed the matched n-gram in the context. None: If no matching n-gram pattern is found. Example: If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and k = 4: - The last 3 (= max_n) tokens [4,2,3] cannot find a match. - The last 2 tokens [2,3] will be matched against the previous 4 tokens [1,2,3,4]. - Finding a match of [2,3] would return the tokens that followed that pattern. Here we will return [4,2,3] because we only have three tokens after the match. """ # TODO(woosuk): Optimize this. return _find_longest_matched_ngram_and_propose_tokens( origin_tokens=context_token_ids, min_ngram=self.min_n, max_ngram=self.max_n, max_model_len=self.max_model_len, k=self.k) def load_model(self, *args, **kwargs): # No model to load. pass @jit(nopython=True) def _find_longest_matched_ngram_and_propose_tokens( origin_tokens: np.ndarray, min_ngram: int, max_ngram: int, max_model_len: int, k: int) -> Optional[np.ndarray]: """ Find the longest n-gram which matches the suffix of the given tokens whose length is within [min_ngram, max_ngram] (inclusive). If found, we will extract k right after the matched ngram. """ # Do not generate draft tokens is context is shorter than minimum n-gram total_token = origin_tokens.shape[0] if total_token < min_ngram: return None # Do not generate draft tokens beyond the max model length. k = min(k, max_model_len - total_token) if k <= 0: return None # Flip tokens, and the goal become to find longest ngram # on the rightmost position which matches the prefix with # length [min_n, max_n] (inclusive). tokens = origin_tokens[::-1] # Longest prefix (not including itself) which is a suffix of # the current position. # lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]} # # As ngram is capped by max_ngram to save memory, we only need to # store lps for the first max_ngram prefix. lps = np.zeros(max_ngram, dtype=np.int32) longest_ngram = 0 position = 0 # lps[0] always equal to 0, we start with index 1 prev_lps = 0 i = 1 while i < total_token: # tokens[:prev_lps] is the longest prefix as a suffix of tokens[:i] if tokens[prev_lps] == tokens[i]: # Token match: tokens[:prev_lps+1] is the longest prefix as # a suffix of tokens[:i+1] prev_lps += 1 # Check if we found a longer valid ngram. # # Update position when longest_ngram matched prev_lps, # as we want to get the target n-gram of the earliest position # in the original tokens (i.e. # latest position in the reversed tokens) if prev_lps >= longest_ngram: longest_ngram = prev_lps position = i if i < max_ngram: # Store LPS for the first max_ngram prefix lps[i] = prev_lps if prev_lps == max_ngram: # When prev_lps reached max_ngram, update prev_lps # to lps[max_ngram-1] to avoid matching ngram # longer than max_ngram prev_lps = lps[max_ngram - 1] i += 1 elif prev_lps != 0: # Token mismatch: try the second longest prefix # among all suffix of tokens[:i], # which is the longest prefix of tokens[:prev_lps] prev_lps = lps[prev_lps - 1] else: # Token mismatch, and no more prefix (except empty string) # as a suffix of tokens[:i] i += 1 if longest_ngram < min_ngram: # No valid ngram is found return None # Flip the position back, so in origin_tokens, # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram] # is the matched ngram, so we should start drafting tokens from # total_token-1-position+longest_ngram start_position = total_token - 1 - position + longest_ngram k = min(k, total_token - start_position) return origin_tokens[start_position:start_position + k]