Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c822298
Unverified
Commit
4c822298
authored
Feb 18, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 18, 2025
Browse files
[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
c8d70e24
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
60 deletions
+67
-60
requirements-common.txt
requirements-common.txt
+1
-0
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+55
-58
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-2
No files found.
requirements-common.txt
View file @
4c822298
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
requests >= 2.26.0
tqdm
blake3
...
...
vllm/v1/spec_decode/ngram_proposer.py
View file @
4c822298
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
typing
import
Optional
import
numpy
as
np
from
numba
import
jit
class
NgramProposer
:
def
__init__
(
self
):
pass
def
propose
(
self
,
context_token_ids
:
np
.
ndarray
,
...
...
@@ -21,7 +19,7 @@ class NgramProposer:
that match.
Args:
context_token_ids:
List
of token IDs representing the
context_token_ids:
Numpy array
of token IDs representing the
context sequence.
n: Length of the n-gram to match.
k: Number of tokens that follow the match. If there are fewer
...
...
@@ -41,17 +39,16 @@ class NgramProposer:
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# TODO: Use c++ to implement the _find_subarray_kmp to
# improve the efficiency
return
self
.
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
return
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
@
staticmethod
def
_kmp_lps_array
(
pattern
:
List
[
int
])
->
List
[
int
]
:
@
jit
(
nopython
=
True
)
def
_kmp_lps_array
(
pattern
:
np
.
ndarray
)
->
np
.
ndarray
:
"""
Build the lps (longest proper prefix that is also a suffix)
array for the pattern.
"""
lps
=
[
0
]
*
len
(
pattern
)
lps
=
np
.
zeros
(
len
(
pattern
)
,
dtype
=
np
.
int32
)
prev_lps
=
0
# length of the previous longest prefix suffix
i
=
1
...
...
@@ -66,21 +63,21 @@ class NgramProposer:
else
:
lps
[
i
]
=
0
i
+=
1
return
lps
@
staticmethod
def
_find_subarray_kmp
(
@
jit
(
nopython
=
True
)
def
_find_subarray_kmp
(
context_token_ids
:
np
.
ndarray
,
n
:
int
,
k
:
int
,
)
->
Optional
[
np
.
ndarray
]:
)
->
Optional
[
np
.
ndarray
]:
context_len
=
context_token_ids
.
shape
[
0
]
assert
n
>
0
pattern
=
context_token_ids
[
-
n
:]
# Precompute lps array for the pattern
lps
=
NgramProposer
.
_kmp_lps_array
(
pattern
)
lps
=
_kmp_lps_array
(
pattern
)
i
=
0
j
=
0
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4c822298
...
...
@@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Set up speculative decoding.
self
.
use_spec_decode
=
False
if
self
.
speculative_config
:
self
.
use_spec_decode
=
True
# TODO: find a better way to check if we are using ngram.
assert
self
.
speculative_config
.
ngram_prompt_lookup_min
,
\
"Currently, only ngram spec decode is supported in V1."
if
get_pp_group
().
is_last_rank
:
self
.
drafter
=
NgramProposer
()
self
.
use_spec_decode
=
True
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self
.
drafter
.
propose
(
np
.
zeros
(
1024
,
dtype
=
np
.
int32
),
self
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
speculative_config
.
num_speculative_tokens
,
)
# Request states.
self
.
requests
:
Dict
[
str
,
CachedRequestState
]
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment