Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c822298
Unverified
Commit
4c822298
authored
Feb 18, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 18, 2025
Browse files
[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
c8d70e24
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
60 deletions
+67
-60
requirements-common.txt
requirements-common.txt
+1
-0
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+55
-58
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-2
No files found.
requirements-common.txt
View file @
4c822298
psutil
psutil
sentencepiece # Required for LLaMA tokenizer.
sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
numpy < 2.0.0
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
requests >= 2.26.0
requests >= 2.26.0
tqdm
tqdm
blake3
blake3
...
...
vllm/v1/spec_decode/ngram_proposer.py
View file @
4c822298
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
typing
import
Optional
import
numpy
as
np
import
numpy
as
np
from
numba
import
jit
class
NgramProposer
:
class
NgramProposer
:
def
__init__
(
self
):
pass
def
propose
(
def
propose
(
self
,
self
,
context_token_ids
:
np
.
ndarray
,
context_token_ids
:
np
.
ndarray
,
...
@@ -21,7 +19,7 @@ class NgramProposer:
...
@@ -21,7 +19,7 @@ class NgramProposer:
that match.
that match.
Args:
Args:
context_token_ids:
List
of token IDs representing the
context_token_ids:
Numpy array
of token IDs representing the
context sequence.
context sequence.
n: Length of the n-gram to match.
n: Length of the n-gram to match.
k: Number of tokens follow the match. If there are less
k: Number of tokens follow the match. If there are less
...
@@ -41,17 +39,16 @@ class NgramProposer:
...
@@ -41,17 +39,16 @@ class NgramProposer:
followed that pattern. Here we will return [4,2,3] because
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
we only have three tokens after the match.
"""
"""
# TODO: Use c++ to implement the _find_subarray_kmp to
return
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
# improve the efficiency
return
self
.
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
@
staticmethod
@
jit
(
nopython
=
True
)
def
_kmp_lps_array
(
pattern
:
List
[
int
])
->
List
[
int
]
:
def
_kmp_lps_array
(
pattern
:
np
.
ndarray
)
->
np
.
ndarray
:
"""
"""
Build the lps (longest proper prefix which is also suffix)
Build the lps (longest proper prefix which is also suffix)
array for the pattern.
array for the pattern.
"""
"""
lps
=
[
0
]
*
len
(
pattern
)
lps
=
np
.
zeros
(
len
(
pattern
)
,
dtype
=
np
.
int32
)
prev_lps
=
0
# length of the previous longest prefix suffix
prev_lps
=
0
# length of the previous longest prefix suffix
i
=
1
i
=
1
...
@@ -66,21 +63,21 @@ class NgramProposer:
...
@@ -66,21 +63,21 @@ class NgramProposer:
else
:
else
:
lps
[
i
]
=
0
lps
[
i
]
=
0
i
+=
1
i
+=
1
return
lps
return
lps
@
staticmethod
def
_find_subarray_kmp
(
@
jit
(
nopython
=
True
)
def
_find_subarray_kmp
(
context_token_ids
:
np
.
ndarray
,
context_token_ids
:
np
.
ndarray
,
n
:
int
,
n
:
int
,
k
:
int
,
k
:
int
,
)
->
Optional
[
np
.
ndarray
]:
)
->
Optional
[
np
.
ndarray
]:
context_len
=
context_token_ids
.
shape
[
0
]
context_len
=
context_token_ids
.
shape
[
0
]
assert
n
>
0
assert
n
>
0
pattern
=
context_token_ids
[
-
n
:]
pattern
=
context_token_ids
[
-
n
:]
# Precompute lps array for Y
# Precompute lps array for Y
lps
=
NgramProposer
.
_kmp_lps_array
(
pattern
)
lps
=
_kmp_lps_array
(
pattern
)
i
=
0
i
=
0
j
=
0
j
=
0
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4c822298
...
@@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -120,11 +120,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Set up speculative decoding.
# Set up speculative decoding.
self
.
use_spec_decode
=
False
self
.
use_spec_decode
=
False
if
self
.
speculative_config
:
if
self
.
speculative_config
:
self
.
use_spec_decode
=
True
# TODO: find a better way to check if we are using ngram.
# TODO: find a better way to check if we are using ngram.
assert
self
.
speculative_config
.
ngram_prompt_lookup_min
,
\
assert
self
.
speculative_config
.
ngram_prompt_lookup_min
,
\
"Currently, only ngram spec decode is supported in V1."
"Currently, only ngram spec decode is supported in V1."
if
get_pp_group
().
is_last_rank
:
self
.
drafter
=
NgramProposer
()
self
.
drafter
=
NgramProposer
()
self
.
use_spec_decode
=
True
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self
.
drafter
.
propose
(
np
.
zeros
(
1024
,
dtype
=
np
.
int32
),
self
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
speculative_config
.
num_speculative_tokens
,
)
# Request states.
# Request states.
self
.
requests
:
Dict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
Dict
[
str
,
CachedRequestState
]
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment