Unverified commit 4c822298, authored by Woosuk Kwon and committed by GitHub
Browse files

[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent c8d70e24
psutil psutil
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0 numpy < 2.0.0
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
requests >= 2.26.0 requests >= 2.26.0
tqdm tqdm
blake3 blake3
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from typing import Optional
import numpy as np import numpy as np
from numba import jit
class NgramProposer: class NgramProposer:
def __init__(self):
pass
def propose( def propose(
self, self,
context_token_ids: np.ndarray, context_token_ids: np.ndarray,
...@@ -21,7 +19,7 @@ class NgramProposer: ...@@ -21,7 +19,7 @@ class NgramProposer:
that match. that match.
Args: Args:
context_token_ids: List of token IDs representing the context_token_ids: Numpy array of token IDs representing the
context sequence. context sequence.
n: Length of the n-gram to match. n: Length of the n-gram to match.
k: Number of tokens that follow the match. If there are fewer k: Number of tokens that follow the match. If there are fewer
...@@ -41,17 +39,16 @@ class NgramProposer: ...@@ -41,17 +39,16 @@ class NgramProposer:
followed that pattern. Here we will return [4,2,3] because followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match. we only have three tokens after the match.
""" """
# TODO: Use c++ to implement the _find_subarray_kmp to return _find_subarray_kmp(context_token_ids, n, k)
# improve the efficiency
return self._find_subarray_kmp(context_token_ids, n, k)
@staticmethod @jit(nopython=True)
def _kmp_lps_array(pattern: List[int]) -> List[int]: def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
""" """
Build the lps (longest proper prefix which is also suffix) Build the lps (longest proper prefix which is also suffix)
array for the pattern. array for the pattern.
""" """
lps = [0] * len(pattern) lps = np.zeros(len(pattern), dtype=np.int32)
prev_lps = 0 # length of the previous longest prefix suffix prev_lps = 0 # length of the previous longest prefix suffix
i = 1 i = 1
...@@ -66,21 +63,21 @@ class NgramProposer: ...@@ -66,21 +63,21 @@ class NgramProposer:
else: else:
lps[i] = 0 lps[i] = 0
i += 1 i += 1
return lps return lps
@staticmethod
def _find_subarray_kmp( @jit(nopython=True)
def _find_subarray_kmp(
context_token_ids: np.ndarray, context_token_ids: np.ndarray,
n: int, n: int,
k: int, k: int,
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
context_len = context_token_ids.shape[0] context_len = context_token_ids.shape[0]
assert n > 0 assert n > 0
pattern = context_token_ids[-n:] pattern = context_token_ids[-n:]
# Precompute lps array for Y # Precompute lps array for Y
lps = NgramProposer._kmp_lps_array(pattern) lps = _kmp_lps_array(pattern)
i = 0 i = 0
j = 0 j = 0
......
...@@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Set up speculative decoding. # Set up speculative decoding.
self.use_spec_decode = False self.use_spec_decode = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True
# TODO: find a better way to check if we are using ngram. # TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \ assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1." "Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank:
self.drafter = NgramProposer() self.drafter = NgramProposer()
self.use_spec_decode = True # Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
# Request states. # Request states.
self.requests: Dict[str, CachedRequestState] = {} self.requests: Dict[str, CachedRequestState] = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment