Unverified Commit 4c822298 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent c8d70e24
psutil psutil
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0 numpy < 2.0.0
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
requests >= 2.26.0 requests >= 2.26.0
tqdm tqdm
blake3 blake3
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from typing import Optional
import numpy as np import numpy as np
from numba import jit
class NgramProposer: class NgramProposer:
def __init__(self):
pass
def propose( def propose(
self, self,
context_token_ids: np.ndarray, context_token_ids: np.ndarray,
...@@ -21,7 +19,7 @@ class NgramProposer: ...@@ -21,7 +19,7 @@ class NgramProposer:
that match. that match.
Args: Args:
context_token_ids: List of token IDs representing the context_token_ids: Numpy array of token IDs representing the
context sequence. context sequence.
n: Length of the n-gram to match. n: Length of the n-gram to match.
k: Number of tokens follow the match. If there are less k: Number of tokens follow the match. If there are less
...@@ -41,66 +39,65 @@ class NgramProposer: ...@@ -41,66 +39,65 @@ class NgramProposer:
followed that pattern. Here we will return [4,2,3] because followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match. we only have three tokens after the match.
""" """
# TODO: Use c++ to implement the _find_subarray_kmp to return _find_subarray_kmp(context_token_ids, n, k)
# improve the efficiency
return self._find_subarray_kmp(context_token_ids, n, k)
@staticmethod
def _kmp_lps_array(pattern: List[int]) -> List[int]:
"""
Build the lps (longest proper prefix which is also suffix)
array for the pattern.
"""
lps = [0] * len(pattern)
prev_lps = 0 # length of the previous longest prefix suffix
i = 1
while i < len(pattern): @jit(nopython=True)
if pattern[i] == pattern[prev_lps]: def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
prev_lps += 1 """
lps[i] = prev_lps Build the lps (longest proper prefix which is also suffix)
i += 1 array for the pattern.
"""
lps = np.zeros(len(pattern), dtype=np.int32)
prev_lps = 0 # length of the previous longest prefix suffix
i = 1
while i < len(pattern):
if pattern[i] == pattern[prev_lps]:
prev_lps += 1
lps[i] = prev_lps
i += 1
else:
if prev_lps != 0:
prev_lps = lps[prev_lps - 1]
else: else:
if prev_lps != 0: lps[i] = 0
prev_lps = lps[prev_lps - 1] i += 1
else: return lps
lps[i] = 0
i += 1
return lps
@staticmethod @jit(nopython=True)
def _find_subarray_kmp( def _find_subarray_kmp(
context_token_ids: np.ndarray, context_token_ids: np.ndarray,
n: int, n: int,
k: int, k: int,
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
context_len = context_token_ids.shape[0] context_len = context_token_ids.shape[0]
assert n > 0 assert n > 0
pattern = context_token_ids[-n:] pattern = context_token_ids[-n:]
# Precompute lps array for Y # Precompute lps array for Y
lps = NgramProposer._kmp_lps_array(pattern) lps = _kmp_lps_array(pattern)
i = 0 i = 0
j = 0 j = 0
# -n because the last n tokens are used as pattern # -n because the last n tokens are used as pattern
while i < context_len - n: while i < context_len - n:
if context_token_ids[i] == pattern[j]: if context_token_ids[i] == pattern[j]:
i += 1 i += 1
j += 1 j += 1
# If we have matched the entire Y # If we have matched the entire Y
if j == n: if j == n:
# Found pattern in context, gather the next K elements # Found pattern in context, gather the next K elements
return context_token_ids[i:i + k] return context_token_ids[i:i + k]
else:
# Mismatch
if j != 0:
# Use the lps array to avoid re-checking elements
j = lps[j - 1]
else: else:
# Mismatch i += 1
if j != 0:
# Use the lps array to avoid re-checking elements
j = lps[j - 1]
else:
i += 1
# Y not found # Y not found
return None return None
...@@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Set up speculative decoding. # Set up speculative decoding.
self.use_spec_decode = False self.use_spec_decode = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True
# TODO: find a better way to check if we are using ngram. # TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \ assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1." "Currently, only ngram spec decode is supported in V1."
self.drafter = NgramProposer() if get_pp_group().is_last_rank:
self.use_spec_decode = True self.drafter = NgramProposer()
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
# Request states. # Request states.
self.requests: Dict[str, CachedRequestState] = {} self.requests: Dict[str, CachedRequestState] = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment