Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
63375f0c
Unverified
Commit
63375f0c
authored
Apr 04, 2025
by
Woosuk Kwon
Committed by
GitHub
Apr 04, 2025
Browse files
[V1][Spec Decode] Update N-gram Proposer Interface (#15750)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
70ad3f9e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
18 deletions
+16
-18
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+15
-13
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-5
No files found.
vllm/v1/spec_decode/ngram_proposer.py
View file @
63375f0c
...
@@ -10,14 +10,21 @@ from vllm.config import VllmConfig
...
@@ -10,14 +10,21 @@ from vllm.config import VllmConfig
class
NgramProposer
:
class
NgramProposer
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
self
.
vllm_config
=
vllm_config
# Minimum length of the n-gram to match.
self
.
min_n
=
vllm_config
.
speculative_config
.
prompt_lookup_min
# Maximum length of the n-gram to match.
self
.
max_n
=
vllm_config
.
speculative_config
.
prompt_lookup_max
# Number of tokens that follow the match. If fewer than k
# tokens follow the match, we will return as many tokens as
# are available until the end.
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self
.
propose
(
np
.
zeros
(
1024
,
dtype
=
np
.
int32
))
def
propose
(
def
propose
(
self
,
self
,
context_token_ids
:
np
.
ndarray
,
context_token_ids
:
np
.
ndarray
,
min_n
:
int
,
max_n
:
int
,
k
:
int
,
)
->
Optional
[
np
.
ndarray
]:
)
->
Optional
[
np
.
ndarray
]:
"""Proposes the next sequence of tokens based on n-gram pattern
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
matching in the context. The function finds matches of the last n
...
@@ -27,11 +34,6 @@ class NgramProposer:
...
@@ -27,11 +34,6 @@ class NgramProposer:
Args:
Args:
context_token_ids: Numpy array of token IDs representing the
context_token_ids: Numpy array of token IDs representing the
context sequence.
context sequence.
min_n: Minimum length of the n-gram to match.
max_n: Maximum length of the n-gram to match.
k: Number of tokens that follow the match. If fewer
than k tokens follow the match, we will return
as many tokens as are available until the end.
Returns:
Returns:
np.ndarray: The sequence of tokens that followed
np.ndarray: The sequence of tokens that followed
...
@@ -49,8 +51,8 @@ class NgramProposer:
...
@@ -49,8 +51,8 @@ class NgramProposer:
we only have three tokens after the match.
we only have three tokens after the match.
"""
"""
# TODO(woosuk): Optimize this.
# TODO(woosuk): Optimize this.
for
n
in
range
(
max_n
,
min_n
-
1
,
-
1
):
for
n
in
range
(
self
.
max_n
,
self
.
min_n
-
1
,
-
1
):
result
=
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
result
=
_find_subarray_kmp
(
context_token_ids
,
n
,
self
.
k
)
if
result
is
not
None
:
if
result
is
not
None
:
return
result
return
result
return
None
return
None
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
63375f0c
...
@@ -1246,11 +1246,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1246,11 +1246,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
end_idx
=
start_idx
+
num_sampled_ids
end_idx
=
start_idx
+
num_sampled_ids
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
drafter_output
=
self
.
drafter
.
propose
(
drafter_output
=
self
.
drafter
.
propose
(
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
],
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
])
self
.
speculative_config
.
prompt_lookup_min
,
self
.
speculative_config
.
prompt_lookup_max
,
self
.
speculative_config
.
num_speculative_tokens
,
)
if
drafter_output
is
None
or
len
(
drafter_output
)
==
0
:
if
drafter_output
is
None
or
len
(
drafter_output
)
==
0
:
draft_token_ids
.
append
([])
draft_token_ids
.
append
([])
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment