Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
83450458
Unverified
Commit
83450458
authored
Oct 16, 2024
by
Lily Liu
Committed by
GitHub
Oct 16, 2024
Browse files
[Performance][Spec Decode] Optimize ngram lookup performance (#9333)
parent
5b8a1fde
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
6 deletions
+11
-6
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+11
-6
No files found.
vllm/spec_decode/ngram_worker.py
View file @
83450458
...
@@ -67,9 +67,16 @@ class NGramWorker(NonLLMProposerWorkerBase):
...
@@ -67,9 +67,16 @@ class NGramWorker(NonLLMProposerWorkerBase):
execute_model_req
.
seq_group_metadata_list
):
execute_model_req
.
seq_group_metadata_list
):
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_len
=
seq_data
.
get_len
()
# When seq_len is less than 3072 (3K), we use CPU to perform
# the ngram match. Otherwise, we use the device specified in
# the model config (normally GPU). 3072 is a rough threshold
# based on profiling on H100, and it can be adjusted based
# on the actual performance on different hardware.
cur_device
=
"cpu"
if
seq_len
<
3072
else
self
.
device
input_ids
=
torch
.
as_tensor
(
seq_data
.
get_token_ids
(),
input_ids
=
torch
.
as_tensor
(
seq_data
.
get_token_ids
(),
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
cur_
device
)
input_length
=
seq_data
.
get_len
()
input_length
=
seq_data
.
get_len
()
for
ngram_size
in
range
(
for
ngram_size
in
range
(
...
@@ -91,17 +98,15 @@ class NGramWorker(NonLLMProposerWorkerBase):
...
@@ -91,17 +98,15 @@ class NGramWorker(NonLLMProposerWorkerBase):
# first_match includes "values" (bool), indicating whether
# first_match includes "values" (bool), indicating whether
# the match is found, and "indices", indicating the index
# the match is found, and "indices", indicating the index
# of the first match.
# of the first match.
# Note that "first_match.values.item()" triggers GPU-CPU
# sync so it is a bit inefficient, but we have not found
# a better way to do this.
first_match
=
matches
.
max
(
dim
=-
1
)
first_match
=
matches
.
max
(
dim
=-
1
)
if
first_match
.
values
.
item
():
if
first_match
.
values
.
item
():
proposal_start_idx
=
first_match
.
indices
.
add_
(
ngram_size
)
proposal_start_idx
=
first_match
.
indices
.
add_
(
ngram_size
)
spec_indices
=
(
spec_indices
=
(
proposal_start_idx
).
repeat
(
sample_len
)
+
torch
.
arange
(
proposal_start_idx
).
repeat
(
sample_len
)
+
torch
.
arange
(
sample_len
,
device
=
self
.
device
)
sample_len
,
device
=
cur_
device
)
spec_indices
.
clamp_
(
max
=
input_ids
.
shape
[
-
1
]
-
1
)
spec_indices
.
clamp_
(
max
=
input_ids
.
shape
[
-
1
]
-
1
)
res
=
input_ids
.
gather
(
dim
=-
1
,
index
=
spec_indices
)
res
=
input_ids
.
gather
(
dim
=-
1
,
index
=
spec_indices
).
to
(
self
.
device
)
token_id_list
.
append
(
res
)
token_id_list
.
append
(
res
)
token_prob_list
.
append
(
token_prob_list
.
append
(
torch
.
nn
.
functional
.
one_hot
(
torch
.
nn
.
functional
.
one_hot
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment