Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d7740ea4
"vscode:/vscode.git/clone" did not exist on "dfd951ed9b9eb4af2452764edd808599b5e8901e"
Unverified
Commit
d7740ea4
authored
May 09, 2024
by
SangBin Cho
Committed by
GitHub
May 08, 2024
Browse files
[Core] Optimize sampler get_logprobs (#4594)
parent
cc466a32
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
49 deletions
+68
-49
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+68
-49
No files found.
vllm/model_executor/layers/sampler.py
View file @
d7740ea4
...
@@ -782,13 +782,14 @@ def _get_logprobs(
...
@@ -782,13 +782,14 @@ def _get_logprobs(
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
largest_num_logprobs
,
dim
=-
1
)
dim
=-
1
)
top_logprobs
=
top_logprobs
.
cpu
()
top_token_ids
=
top_token_ids
.
cpu
()
else
:
else
:
top_logprobs
,
top_token_ids
=
None
,
None
top_logprobs
,
top_token_ids
=
None
,
None
selected_logprobs
=
selected_logprobs
.
cpu
()
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
ranks
=
ranks
.
cpu
()
ranks
=
ranks
.
to
(
'cpu'
)
if
top_logprobs
is
not
None
and
top_token_ids
is
not
None
:
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
# Find prompt/sample logprobs.
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
...
@@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed(
...
@@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed(
# Find prompt logprobs
# Find prompt logprobs
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
)
:
if
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
:
prompt_logprobs
=
[]
prompt_logprobs
=
[]
num_logprobs
=
sampling_params
.
prompt_logprobs
num_logprobs
=
sampling_params
.
prompt_logprobs
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
for
token_id
in
next_prompt_tokens
:
# Pre-select indexes and create a list. It is faster than calling .item
# repetitively.
selected_logprob_items
=
selected_logprobs
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_prompt_tokens
)].
tolist
()
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_prompt_tokens
)].
tolist
()
for
idx
,
token_id
in
enumerate
(
next_prompt_tokens
):
# Calculate the prompt logprob of the real prompt tokens.
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
prompt_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
token_id
:
(
selected_logprobs
[
selected_logprobs_idx
].
item
(),
token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
ranks
[
selected_logprobs_idx
].
item
())
}
}
# Add top K prompt logprobs along with its rank.
# Add top K prompt logprobs along with its rank.
if
num_logprobs
>
0
:
if
num_logprobs
>
0
:
prompt_logprobs_dict
.
update
(
top_ids
=
top_token_ids
[
zip
(
top_logprob_idx
,
:
num_logprobs
].
tolist
()
top_token_ids
[
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
top_probs
=
top_logprobs
[
zip
(
top_logprob_idx
,
:
num_logprobs
].
tolist
()
top_logprobs
[
# Top K is already sorted by rank, so we can use 1 ~
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
# num_logprobs + 1 for rank.
# This is ranks. Since top_logprob is sorted,
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
# we can just use a range here.
prompt_logprobs_dict
.
update
({
range
(
1
,
num_logprobs
+
1
))))
top_id
:
(
top_prob
,
rank
)
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_ranks
)
})
prompt_logprobs
.
append
({
prompt_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
prompt_logprobs_dict
.
items
()
for
token_id
,
logprob_and_rank
in
prompt_logprobs_dict
.
items
()
})
})
# + 1 to go to the next prompt token.
# + 1 to go to the next prompt token.
top_logprob_idx
+=
1
top_logprob_idx
+=
1
selected_logprobs_idx
+=
1
# + len(next_prompt_tokens) to go to the next prompt.
selected_logprobs_idx
+=
len
(
next_prompt_tokens
)
return
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
return
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
...
@@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed(
...
@@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed(
):
):
"""Compute the sample logprob if needed."""
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
or
0
if
num_logprobs
is
None
:
num_logprobs
=
0
sampled_logprobs
:
SampleLogprobs
=
[]
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
assert
len
(
next_token_ids
)
>
0
for
(
next_token_id
,
parent_id
)
in
zip
(
next_token_ids
,
parent_seq_ids
):
# Pre-select items from tensor. tolist() is faster than repetitive
# Calculate the sample logprob of the real sampled tokens.
# `.item()` calls.
# Use tuple here for performance (to use to_list()).
selected_logprob_items
=
selected_logprobs
[
# token_id: (logprob, rank_from_vocab)
selected_logprobs_idx
:
selected_logprobs_idx
+
sampled_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
len
(
next_token_ids
)].
tolist
()
next_token_id
:
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
(
selected_logprobs
[
selected_logprobs_idx
].
item
(),
len
(
next_token_ids
)].
tolist
()
ranks
[
selected_logprobs_idx
].
item
())
for
idx
,
(
next_token_id
,
parent_id
)
in
enumerate
(
zip
(
next_token_ids
,
parent_seq_ids
)):
# Get the logprob of a sampled token.
sampled_logprobs_dict
=
{
next_token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
}
}
# +1 to go to the next sampled token. Note that
# Get top K logprobs.
# selected_logprobs can contain duplicates unlike top_logprobs
if
num_logprobs
>
0
:
# when beam search is enabled.
top_ids
=
top_token_ids
[
top_logprob_idx
+
selected_logprobs_idx
+=
1
parent_id
,
:
num_logprobs
].
tolist
()
top_probs
=
top_logprobs
[
top_logprob_idx
+
# Second, add top K logprobs along with its rank.
parent_id
,
:
num_logprobs
].
tolist
()
if
num_logprobs
>=
0
:
# Top K is already sorted by rank, so we can use 1 ~
sampled_logprobs_dict
.
update
(
# num_logprobs + 1 for rank.
zip
(
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
top_token_ids
[
top_logprob_idx
+
sampled_logprobs_dict
.
update
({
parent_id
,
:
num_logprobs
].
tolist
(),
top_id
:
(
top_prob
,
rank
)
zip
(
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_logprobs
[
top_logprob_idx
+
top_ranks
)
parent_id
,
:
num_logprobs
].
tolist
(),
})
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range
(
1
,
num_logprobs
+
1
))))
sampled_logprobs
.
append
({
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
sampled_logprobs_dict
.
items
()
})
})
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group.
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
# steps, which has len(seq_ids) tokens per sequence group.
# Iterate to the next sequence group in a batch.
selected_logprobs_idx
+=
len
(
next_token_ids
)
# Iterate to the next sequence group in a batch.
top_logprob_idx
+=
len
(
seq_ids
)
top_logprob_idx
+=
len
(
seq_ids
)
return
sampled_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
return
sampled_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment