Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db9e5708
Unverified
Commit
db9e5708
authored
Jul 30, 2024
by
Peng Guanwen
Committed by
GitHub
Jul 29, 2024
Browse files
[Core] Reduce unnecessary compute when logprobs=None (#6532)
parent
766435e6
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
135 additions
and
80 deletions
+135
-80
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+37
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+81
-63
vllm/outputs.py
vllm/outputs.py
+9
-8
vllm/sampling_params.py
vllm/sampling_params.py
+8
-7
No files found.
tests/samplers/test_logprobs.py
View file @
db9e5708
...
...
@@ -14,7 +14,7 @@ MODELS = ["facebook/opt-125m"]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# needed for comparing logprobs with HF
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
6
])
# 32000 == vocab_size
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
0
,
6
])
# 32000 == vocab_size
@
pytest
.
mark
.
parametrize
(
"detokenize"
,
[
True
,
False
])
def
test_get_prompt_logprobs
(
hf_runner
,
...
...
@@ -63,7 +63,10 @@ def test_get_prompt_logprobs(
assert
result
.
outputs
[
0
].
logprobs
is
not
None
assert
len
(
result
.
outputs
[
0
].
logprobs
)
==
max_tokens
for
logprobs
in
result
.
outputs
[
0
].
logprobs
:
assert
len
(
logprobs
)
==
num_top_logprobs
# If the output token is not included in the top X
# logprob, it can return 1 more data
assert
(
len
(
logprobs
)
==
num_top_logprobs
or
len
(
logprobs
)
==
num_top_logprobs
+
1
)
output_text
=
result
.
outputs
[
0
].
text
output_string_from_most_likely_tokens_lst
:
List
[
str
]
=
[]
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
...
...
@@ -135,3 +138,35 @@ def test_max_logprobs():
bad_sampling_params
=
SamplingParams
(
logprobs
=
2
)
with
pytest
.
raises
(
ValueError
):
runner
.
generate
([
"Hello world"
],
sampling_params
=
bad_sampling_params
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("detokenize", [True, False])
def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
                       detokenize: bool, example_prompts):
    """Check that ``logprobs=None`` produces outputs without any logprobs.

    Generates a few tokens for every example prompt with
    ``SamplingParams(logprobs=None)`` and asserts that the first completion
    of each result has both ``logprobs`` and ``cumulative_logprob`` set to
    ``None``.

    Args:
        vllm_runner: Fixture used as a context manager to build the model
            runner.
        model: Model name, parametrized over ``MODELS``.
        chunked_prefill_token_size: Token budget for chunked prefill; ``-1``
            leaves chunked prefill disabled.
        detokenize: Forwarded to ``SamplingParams(detokenize=...)``.
        example_prompts: Fixture providing the prompts to generate from.
    """
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        # With chunked prefill on, cap both the number of sequences and the
        # number of batched tokens at the requested chunk size.
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size
    max_tokens = 5

    with vllm_runner(
            model,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
    ) as vllm_model:
        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
                                                       logprobs=None,
                                                       temperature=0.0,
                                                       detokenize=detokenize)
        results_logprobs_none = vllm_model.model.generate(
            example_prompts, sampling_params=sampling_params_logprobs_none)

    # Iterate over the results directly instead of indexing by position.
    for result in results_logprobs_none:
        first_output = result.outputs[0]
        assert first_output.logprobs is None
        assert first_output.cumulative_logprob is None
vllm/model_executor/layers/sampler.py
View file @
db9e5708
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -774,8 +775,11 @@ def _get_logprobs(
# The next token ids to get the logprob value from.
next_token_ids
:
List
[
int
]
=
[]
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs
=
1
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
largest_num_logprobs
=
-
1
# If beam search is enabled.
use_beam_search
=
False
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
...
...
@@ -808,6 +812,8 @@ def _get_logprobs(
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
logprobs
)
use_beam_search
=
use_beam_search
or
sampling_params
.
use_beam_search
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
if
len
(
query_indices
)
==
0
:
...
...
@@ -815,8 +821,15 @@ def _get_logprobs(
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
selected_logprobs
,
ranks
=
None
,
None
top_logprobs
,
top_token_ids
=
None
,
None
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# skip the whole logprob calculation.
if
largest_num_logprobs
>=
0
or
use_beam_search
:
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
...
...
@@ -830,20 +843,18 @@ def _get_logprobs(
)
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
# We need to compute top k only if there exists logprobs > 0.
if
largest_num_logprobs
>
0
:
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
if
largest_num_logprobs
>
0
:
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
dim
=-
1
)
else
:
top_logprobs
,
top_token_ids
=
None
,
None
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
ranks
=
ranks
.
to
(
'cpu'
)
if
top_logprobs
is
not
None
and
top_token_ids
is
not
None
:
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
...
...
@@ -940,12 +951,18 @@ def _get_sampled_logprob_if_needed(
):
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
or
0
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
use_beam_search
=
seq_group
.
sampling_params
.
use_beam_search
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
if
num_logprobs
is
None
and
not
use_beam_search
:
for
next_token_id
in
next_token_ids
:
# Use a dummy logprob
sampled_logprobs
.
append
({
next_token_id
:
Logprob
(
inf
)})
else
:
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items
=
selected_logprobs
[
...
...
@@ -953,25 +970,26 @@ def _get_sampled_logprob_if_needed(
len
(
next_token_ids
)].
tolist
()
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_token_ids
)].
tolist
()
for
idx
,
(
next_token_id
,
parent_id
)
in
enumerate
(
zip
(
next_token_ids
,
parent_seq_ids
)):
for
idx
,
(
next_token_id
,
parent_id
)
in
enumerate
(
zip
(
next_token_ids
,
parent_seq_ids
)):
# Get the logprob of a sampled token.
sampled_logprobs_dict
=
{
next_token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
next_token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
}
if
num_logprobs
is
not
None
and
num_logprobs
>
0
:
# Get top K logprobs.
if
num_logprobs
>
0
:
top_ids
=
top_token_ids
[
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
top_probs
=
top_logprobs
[
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
top_probs
=
top_logprobs
[
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
sampled_logprobs_dict
.
update
({
top_id
:
(
top_prob
,
rank
)
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_ranks
)
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_ranks
)
})
sampled_logprobs
.
append
({
...
...
vllm/outputs.py
View file @
db9e5708
...
...
@@ -29,7 +29,7 @@ class CompletionOutput:
index
:
int
text
:
str
token_ids
:
Tuple
[
int
,
...]
cumulative_logprob
:
float
cumulative_logprob
:
Optional
[
float
]
logprobs
:
Optional
[
SampleLogprobs
]
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
...
...
@@ -124,10 +124,11 @@ class RequestOutput:
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
outputs
=
[
CompletionOutput
(
seqs
.
index
(
seq
),
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
get_output_token_ids
(),
seq
.
get_cumulative_logprob
(),
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
seq
.
output_logprobs
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
)
for
seq
in
top_n_seqs
...
...
vllm/sampling_params.py
View file @
db9e5708
...
...
@@ -92,11 +92,12 @@ class SamplingParams:
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
logprobs: Number of log probabilities to return per output token.
Note that the implementation follows the OpenAI API: The return
result includes the log probabilities on the `logprobs` most likely
tokens, as well the chosen tokens. The API will always return the
log probability of the sampled token, so there may be up to
`logprobs+1` elements in the response.
When set to None, no probability is returned. If set to a non-None
value, the result includes the log probabilities of the specified
number of most likely tokens, as well as the chosen tokens.
Note that the implementation follows the OpenAI API: The API will
always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
...
...
@@ -168,8 +169,8 @@ class SamplingParams:
self
.
ignore_eos
=
ignore_eos
self
.
max_tokens
=
max_tokens
self
.
min_tokens
=
min_tokens
self
.
logprobs
=
logprobs
self
.
prompt_logprobs
=
prompt_logprobs
self
.
logprobs
=
1
if
logprobs
is
True
else
logprobs
self
.
prompt_logprobs
=
1
if
prompt_logprobs
is
True
else
prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment