Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ce532ff4
"vllm/vscode:/vscode.git/clone" did not exist on "f5dda63eb5fcb5624b93fa5f09da01d5372bbce4"
Unverified
Commit
ce532ff4
authored
May 13, 2024
by
Cody Yu
Committed by
GitHub
May 13, 2024
Browse files
[Speculative decoding] Improve n-gram efficiency (#4724)
parent
8bc68e19
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
132 additions
and
58 deletions
+132
-58
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+9
-8
vllm/config.py
vllm/config.py
+8
-5
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+52
-45
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+63
-0
No files found.
tests/spec_decode/test_ngram_worker.py
View file @
ce532ff4
...
...
@@ -34,8 +34,8 @@ def test_ngram_algo_correctness_for_single_no_match():
max_proposal_len
=
20
,
)
# set ngram window
(0
, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
# set ngram window
[1
, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
1
,
3
)
prompts
=
[
# shall find no candidate
...
...
@@ -90,8 +90,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
max_proposal_len
=
20
,
)
# set ngram window
(0
, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
# set ngram window
[1
, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
1
,
3
)
prompts
=
[
# shall find no candidate
...
...
@@ -128,11 +128,12 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
5
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
5
])
# the first sequence has no match so proposal_len should be overwritten to 0
assert
proposals
.
proposal_lens
.
tolist
(
)
==
[
proposal_len
for
_
in
range
(
4
)]
+
[
0
]
)
==
[
0
]
+
[
proposal_len
for
_
in
range
(
3
)]
+
[
0
]
for
i
in
range
(
proposal_len
):
assert
proposals
.
proposal_token_ids
[
0
][
i
]
==
0
assert
proposals
.
proposal_token_ids
[
0
][
i
]
==
-
1
assert
proposals
.
proposal_token_ids
[
1
][
i
]
==
prompts
[
1
][
i
+
1
]
assert
proposals
.
proposal_token_ids
[
2
][
i
]
==
prompts
[
2
][
i
+
3
]
assert
proposals
.
proposal_token_ids
[
3
][
i
]
==
prompts
[
3
][
i
+
5
]
...
...
@@ -167,8 +168,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
max_proposal_len
=
20
,
)
# set ngram window
(
0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
# set ngram window
[
0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
1
,
3
)
prompts
=
[
# shall find candidate 12,13,14,15,16
...
...
vllm/config.py
View file @
ce532ff4
...
...
@@ -784,12 +784,15 @@ class SpeculativeConfig:
draft_quantization
=
None
if
speculative_model
==
"[ngram]"
:
assert
(
ngram_prompt_lookup_max
is
not
None
and
ngram_prompt_lookup_max
>
0
)
if
ngram_prompt_lookup_min
is
None
:
ngram_prompt_lookup_min
=
0
else
:
assert
ngram_prompt_lookup_max
>
ngram_prompt_lookup_min
ngram_prompt_lookup_min
=
1
if
ngram_prompt_lookup_max
is
None
or
ngram_prompt_lookup_max
<
1
:
raise
ValueError
(
f
"
{
ngram_prompt_lookup_max
=
}
must be > 0"
)
if
ngram_prompt_lookup_min
<
1
:
raise
ValueError
(
f
"
{
ngram_prompt_lookup_min
=
}
must be > 0"
)
if
ngram_prompt_lookup_min
>
ngram_prompt_lookup_max
:
raise
ValueError
(
f
"
{
ngram_prompt_lookup_min
=
}
cannot be "
f
"larger than
{
ngram_prompt_lookup_max
=
}
"
)
# TODO: current we still need extract vocab_size from target model
# config, in future, we may try refactor it out, and set
...
...
vllm/spec_decode/ngram_worker.py
View file @
ce532ff4
...
...
@@ -77,9 +77,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
arr
=
[]
has_spec_out
=
False
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
:
token_id_list
=
[]
token_prob_list
=
[]
for
idx
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
input_ids
=
torch
.
as_tensor
(
seq_data
.
get_token_ids
(),
...
...
@@ -89,59 +91,64 @@ class NGramWorker(LoraNotSupportedWorkerBase):
for
ngram_size
in
range
(
min
(
self
.
ngram_prompt_lookup_max
,
input_length
-
1
),
self
.
ngram_prompt_lookup_min
,
self
.
ngram_prompt_lookup_min
-
1
,
-
1
,
):
ngram_tensor
=
input_ids
[
-
1
*
ngram_size
:]
windows
=
input_ids
.
unfold
(
dimension
=
0
,
size
=
ngram_size
,
step
=
1
)
matches
=
(
windows
==
ngram_tensor
).
all
(
dim
=
1
)
match_indices
=
matches
.
nonzero
(
as_tuple
=
True
)[
0
]
if
match_indices
.
size
()[
0
]
>
1
:
ngram_tensor
=
input_ids
[
-
ngram_size
:]
proposal_start_idx
=
None
if
ngram_size
==
1
:
# Do not match itself and do not use unfold and all
matches
=
(
input_ids
[:
-
1
]
==
ngram_tensor
)
else
:
windows
=
input_ids
.
unfold
(
dimension
=
0
,
size
=
ngram_size
,
step
=
1
)
# Do not match itself
matches
=
(
windows
[:
-
1
]
==
ngram_tensor
).
all
(
dim
=-
1
)
# first_match includes "values" (bool), indicating whether
# the match is found, and "indices", indicating the index
# of the first match.
# Note that "first_match.values.item()" triggers GPU-CPU
# sync so it is a bit inefficient, but we have not found
# a better way to do this.
first_match
=
matches
.
max
(
dim
=-
1
)
if
first_match
.
values
.
item
():
proposal_start_idx
=
first_match
.
indices
.
add_
(
ngram_size
)
spec_indices
=
(
proposal_start_idx
).
repeat
(
sample_len
)
+
torch
.
arange
(
sample_len
,
device
=
self
.
device
)
spec_indices
.
clamp_
(
max
=
input_ids
.
shape
[
-
1
]
-
1
)
res
=
input_ids
.
gather
(
dim
=-
1
,
index
=
spec_indices
)
token_id_list
.
append
(
res
)
token_prob_list
.
append
(
torch
.
nn
.
functional
.
one_hot
(
res
,
num_classes
=
self
.
vocab_size
).
to
(
torch
.
float32
))
has_spec_out
=
True
res
=
seq_data
.
get_token_ids
()
res
=
res
[
match_indices
[
0
]
+
ngram_size
:
match_indices
[
0
]
+
ngram_size
+
sample_len
]
res_len
=
len
(
res
)
# pad 0 towards output as sample_len tokens required
res
+=
[
0
]
*
(
sample_len
-
res_len
)
break
else
:
# if no candidate found, fill with 0
res
=
[
0
]
*
sample_len
arr
.
append
(
res
)
token_id_list
.
append
(
None
)
token_prob_list
.
append
(
None
)
if
not
has_spec_out
:
return
None
,
False
outputs
=
[]
token_ids
=
torch
.
as_tensor
(
arr
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
indices
=
token_ids
.
unsqueeze
(
2
)
outputs
:
List
[
Optional
[
SamplerOutput
]]
=
[]
for
idx
in
range
(
len
(
execute_model_req
.
seq_group_metadata_list
)):
if
token_id_list
[
idx
]
is
None
:
outputs
.
append
(
None
)
else
:
outputs
.
append
(
SamplerOutput
(
outputs
=
None
,
sampled_token_probs
=
token_prob_list
[
idx
],
logprobs
=
torch
.
zeros
((
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
sampled_token_ids
=
token_id_list
[
idx
],
))
token_probs
=
torch
.
zeros
(
(
len
(
execute_model_req
.
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
token_probs
.
scatter_
(
2
,
indices
,
1
)
token_logprobs
=
torch
.
zeros
(
(
len
(
execute_model_req
.
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
for
i
in
range
(
len
(
execute_model_req
.
seq_group_metadata_list
)):
outputs
.
append
(
SamplerOutput
(
outputs
=
None
,
sampled_token_probs
=
token_probs
[
i
],
logprobs
=
token_logprobs
[
i
],
sampled_token_ids
=
token_ids
[
i
],
))
return
outputs
,
False
def
get_spec_proposals
(
...
...
vllm/spec_decode/top1_proposer.py
View file @
ce532ff4
...
...
@@ -73,6 +73,14 @@ class Top1Proposer(SpeculativeProposer):
execute_model_req
=
nonzero_execute_model_req
,
sample_len
=
proposal_len
,
)
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
)
=
self
.
_remove_no_proposal_seqs
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
)
else
:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output
=
None
...
...
@@ -140,6 +148,61 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_indices
,
)
def
_remove_no_proposal_seqs
(
self
,
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
(maybe_sampler_output=None). This can avoid scoring overheads.
"""
# If maybe_sampler_output is None, then the draft worker did not
# provide a proposal for any sequence and thus no action needed.
# Also we do not support transposed maybe_sampler_output for now
# because it seems not straightforward for draft workers outputting
# transposed sampler outputs to handle the case of no proposal.
if
maybe_sampler_output
is
None
or
transposed
:
return
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
)
new_proposal_lens
:
List
[
int
]
=
[]
new_nonzero_proposal_len_indices
:
List
[
int
]
=
[]
new_maybe_sampler_output
:
List
[
SamplerOutput
]
=
[]
nonzero_proposal_len_idx_ptr
=
0
seq_idx
=
0
while
seq_idx
<
len
(
proposal_lens
)
and
nonzero_proposal_len_idx_ptr
<
len
(
nonzero_proposal_len_indices
):
if
seq_idx
<
nonzero_proposal_len_indices
[
nonzero_proposal_len_idx_ptr
]:
# Sequence is not in the original nonzero_proposal_len_indices,
# meaning that it has a proposal length of 0 before sending to
# the draft worker.
assert
proposal_lens
[
seq_idx
]
==
0
new_proposal_lens
.
append
(
0
)
else
:
# Sequence is in the original nonzero_proposal_len_indices
if
maybe_sampler_output
[
nonzero_proposal_len_idx_ptr
]
is
None
:
# but does not have a proposal from the draft worker.
new_proposal_lens
.
append
(
0
)
else
:
# and has a proposal from the draft worker. Add it to the
# new nonzero proposal list and keep the sampler output.
new_proposal_lens
.
append
(
proposal_lens
[
seq_idx
])
new_nonzero_proposal_len_indices
.
append
(
seq_idx
)
new_maybe_sampler_output
.
append
(
maybe_sampler_output
[
nonzero_proposal_len_idx_ptr
])
nonzero_proposal_len_idx_ptr
+=
1
seq_idx
+=
1
# The remaining sequences should have proposal length of 0.
new_proposal_lens
.
extend
(
proposal_lens
[
seq_idx
:])
# We assume sampler_output will not be a list of all Nones.
# In this case this function should not be called.
assert
new_maybe_sampler_output
return
(
new_proposal_lens
,
new_maybe_sampler_output
,
new_nonzero_proposal_len_indices
)
def
_merge_outputs
(
self
,
batch_size
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment