Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db9e5708
Unverified
Commit
db9e5708
authored
Jul 30, 2024
by
Peng Guanwen
Committed by
GitHub
Jul 29, 2024
Browse files
[Core] Reduce unnecessary compute when logprobs=None (#6532)
parent
766435e6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
135 additions
and
80 deletions
+135
-80
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+37
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+81
-63
vllm/outputs.py
vllm/outputs.py
+9
-8
vllm/sampling_params.py
vllm/sampling_params.py
+8
-7
No files found.
tests/samplers/test_logprobs.py
View file @
db9e5708
...
@@ -14,7 +14,7 @@ MODELS = ["facebook/opt-125m"]
...
@@ -14,7 +14,7 @@ MODELS = ["facebook/opt-125m"]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# needed for comparing logprobs with HF
[
"float"
])
# needed for comparing logprobs with HF
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
6
])
# 32000 == vocab_size
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
0
,
6
])
# 32000 == vocab_size
@
pytest
.
mark
.
parametrize
(
"detokenize"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"detokenize"
,
[
True
,
False
])
def
test_get_prompt_logprobs
(
def
test_get_prompt_logprobs
(
hf_runner
,
hf_runner
,
...
@@ -63,7 +63,10 @@ def test_get_prompt_logprobs(
...
@@ -63,7 +63,10 @@ def test_get_prompt_logprobs(
assert
result
.
outputs
[
0
].
logprobs
is
not
None
assert
result
.
outputs
[
0
].
logprobs
is
not
None
assert
len
(
result
.
outputs
[
0
].
logprobs
)
==
max_tokens
assert
len
(
result
.
outputs
[
0
].
logprobs
)
==
max_tokens
for
logprobs
in
result
.
outputs
[
0
].
logprobs
:
for
logprobs
in
result
.
outputs
[
0
].
logprobs
:
assert
len
(
logprobs
)
==
num_top_logprobs
# If the output token is not included in the top X
# logprob, it can return 1 more data
assert
(
len
(
logprobs
)
==
num_top_logprobs
or
len
(
logprobs
)
==
num_top_logprobs
+
1
)
output_text
=
result
.
outputs
[
0
].
text
output_text
=
result
.
outputs
[
0
].
text
output_string_from_most_likely_tokens_lst
:
List
[
str
]
=
[]
output_string_from_most_likely_tokens_lst
:
List
[
str
]
=
[]
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
...
@@ -135,3 +138,35 @@ def test_max_logprobs():
...
@@ -135,3 +138,35 @@ def test_max_logprobs():
bad_sampling_params
=
SamplingParams
(
logprobs
=
2
)
bad_sampling_params
=
SamplingParams
(
logprobs
=
2
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
runner
.
generate
([
"Hello world"
],
sampling_params
=
bad_sampling_params
)
runner
.
generate
([
"Hello world"
],
sampling_params
=
bad_sampling_params
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"detokenize"
,
[
True
,
False
])
def
test_none_logprobs
(
vllm_runner
,
model
,
chunked_prefill_token_size
:
int
,
detokenize
:
bool
,
example_prompts
):
max_num_seqs
=
256
enable_chunked_prefill
=
False
max_num_batched_tokens
=
None
if
chunked_prefill_token_size
!=
-
1
:
enable_chunked_prefill
=
True
max_num_seqs
=
min
(
chunked_prefill_token_size
,
max_num_seqs
)
max_num_batched_tokens
=
chunked_prefill_token_size
max_tokens
=
5
with
vllm_runner
(
model
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_seqs
=
max_num_seqs
,
)
as
vllm_model
:
sampling_params_logprobs_none
=
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
None
,
temperature
=
0.0
,
detokenize
=
detokenize
)
results_logprobs_none
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
sampling_params_logprobs_none
)
for
i
in
range
(
len
(
results_logprobs_none
)):
assert
results_logprobs_none
[
i
].
outputs
[
0
].
logprobs
is
None
assert
results_logprobs_none
[
i
].
outputs
[
0
].
cumulative_logprob
is
None
vllm/model_executor/layers/sampler.py
View file @
db9e5708
"""A layer that samples the next tokens from the model's outputs."""
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
itertools
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -774,8 +775,11 @@ def _get_logprobs(
...
@@ -774,8 +775,11 @@ def _get_logprobs(
# The next token ids to get the logprob value from.
# The next token ids to get the logprob value from.
next_token_ids
:
List
[
int
]
=
[]
next_token_ids
:
List
[
int
]
=
[]
# The largest requested number of logprobs. We find logprobs as many as the
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
# largest num logprobs in this API. If every logprobs is None, it will be
largest_num_logprobs
=
1
# set to -1.
largest_num_logprobs
=
-
1
# If beam search is enabled.
use_beam_search
=
False
# Select indices to compute logprob from, ranks of token ids, and the top
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
# k token ids from logprobs.
...
@@ -808,6 +812,8 @@ def _get_logprobs(
...
@@ -808,6 +812,8 @@ def _get_logprobs(
largest_num_logprobs
=
max
(
largest_num_logprobs
,
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
logprobs
)
sampling_params
.
logprobs
)
use_beam_search
=
use_beam_search
or
sampling_params
.
use_beam_search
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
if
len
(
query_indices
)
==
0
:
if
len
(
query_indices
)
==
0
:
...
@@ -815,35 +821,40 @@ def _get_logprobs(
...
@@ -815,35 +821,40 @@ def _get_logprobs(
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
selected_logprobs
,
ranks
=
None
,
None
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
top_logprobs
,
top_token_ids
=
None
,
None
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# contain duplicates if beam search is enabled.
# skip the whole logprob calculation.
selected_logprobs
=
logprobs
[[
if
largest_num_logprobs
>=
0
or
use_beam_search
:
query_indices_gpu
,
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
next_token_ids_gpu
,
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
]]
device
=
logprobs
.
device
)
ranks
=
_get_ranks
(
logprobs
[
query_indices_gpu
],
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
next_token_ids_gpu
,
# contain duplicates if beam search is enabled.
)
selected_logprobs
=
logprobs
[[
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
query_indices_gpu
,
next_token_ids_gpu
,
# Logprobs of topk tokens for a batch of sequence groups.
]]
# (num_query_tokens_across_batch).
ranks
=
_get_ranks
(
if
largest_num_logprobs
>
0
:
logprobs
[
query_indices_gpu
],
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
next_token_ids_gpu
,
largest_num_logprobs
,
)
dim
=-
1
)
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
else
:
top_logprobs
,
top_token_ids
=
None
,
None
# We need to compute top k only if there exists logprobs > 0.
if
largest_num_logprobs
>
0
:
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
dim
=-
1
)
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
ranks
=
ranks
.
to
(
'cpu'
)
ranks
=
ranks
.
to
(
'cpu'
)
if
top_logprobs
is
not
None
and
top_token_ids
is
not
None
:
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
# Find prompt/sample logprobs.
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
...
@@ -940,46 +951,53 @@ def _get_sampled_logprob_if_needed(
...
@@ -940,46 +951,53 @@ def _get_sampled_logprob_if_needed(
):
):
"""Compute the sample logprob if needed."""
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
or
0
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
use_beam_search
=
seq_group
.
sampling_params
.
use_beam_search
sampled_logprobs
:
SampleLogprobs
=
[]
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
assert
len
(
next_token_ids
)
>
0
# Pre-select items from tensor. tolist() is faster than repetitive
if
num_logprobs
is
None
and
not
use_beam_search
:
# `.item()` calls.
for
next_token_id
in
next_token_ids
:
selected_logprob_items
=
selected_logprobs
[
# Use a dummy logprob
selected_logprobs_idx
:
selected_logprobs_idx
+
sampled_logprobs
.
append
({
next_token_id
:
Logprob
(
inf
)})
len
(
next_token_ids
)].
tolist
()
else
:
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
# Pre-select items from tensor. tolist() is faster than repetitive
len
(
next_token_ids
)].
tolist
()
# `.item()` calls.
for
idx
,
(
next_token_id
,
selected_logprob_items
=
selected_logprobs
[
parent_id
)
in
enumerate
(
zip
(
next_token_ids
,
parent_seq_ids
)):
selected_logprobs_idx
:
selected_logprobs_idx
+
# Get the logprob of a sampled token.
len
(
next_token_ids
)].
tolist
()
sampled_logprobs_dict
=
{
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
next_token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
len
(
next_token_ids
)].
tolist
()
}
for
idx
,
(
next_token_id
,
parent_id
)
in
enumerate
(
# Get top K logprobs.
zip
(
next_token_ids
,
parent_seq_ids
)):
if
num_logprobs
>
0
:
# Get the logprob of a sampled token.
top_ids
=
top_token_ids
[
top_logprob_idx
+
sampled_logprobs_dict
=
{
parent_id
,
:
num_logprobs
].
tolist
()
next_token_id
:
top_probs
=
top_logprobs
[
top_logprob_idx
+
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
parent_id
,
:
num_logprobs
].
tolist
()
}
# Top K is already sorted by rank, so we can use 1 ~
if
num_logprobs
is
not
None
and
num_logprobs
>
0
:
# num_logprobs + 1 for rank.
# Get top K logprobs.
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
top_ids
=
top_token_ids
[
top_logprob_idx
+
sampled_logprobs_dict
.
update
({
parent_id
,
:
num_logprobs
].
tolist
()
top_id
:
(
top_prob
,
rank
)
top_probs
=
top_logprobs
[
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
top_ranks
)
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
sampled_logprobs_dict
.
update
({
top_id
:
(
top_prob
,
rank
)
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_ranks
)
})
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
})
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
# per sequence group. `logprobs` includes logprobs from the previous
...
...
vllm/outputs.py
View file @
db9e5708
...
@@ -29,7 +29,7 @@ class CompletionOutput:
...
@@ -29,7 +29,7 @@ class CompletionOutput:
index
:
int
index
:
int
text
:
str
text
:
str
token_ids
:
Tuple
[
int
,
...]
token_ids
:
Tuple
[
int
,
...]
cumulative_logprob
:
float
cumulative_logprob
:
Optional
[
float
]
logprobs
:
Optional
[
SampleLogprobs
]
logprobs
:
Optional
[
SampleLogprobs
]
finish_reason
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
...
@@ -124,13 +124,14 @@ class RequestOutput:
...
@@ -124,13 +124,14 @@ class RequestOutput:
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
outputs
=
[
outputs
=
[
CompletionOutput
(
seqs
.
index
(
seq
),
CompletionOutput
(
seq
.
get_output_text_to_return
(
text_buffer_length
),
seqs
.
index
(
seq
),
seq
.
get_output_token_ids
(),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
get_cumulative_logprob
(),
seq
.
get_output_token_ids
(),
seq
.
output_logprobs
if
include_logprobs
else
None
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
output_logprobs
if
include_logprobs
else
None
,
seq
.
stop_reason
)
for
seq
in
top_n_seqs
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
)
for
seq
in
top_n_seqs
]
]
# Every sequence in the sequence group should have the same prompt.
# Every sequence in the sequence group should have the same prompt.
...
...
vllm/sampling_params.py
View file @
db9e5708
...
@@ -92,11 +92,12 @@ class SamplingParams:
...
@@ -92,11 +92,12 @@ class SamplingParams:
min_tokens: Minimum number of tokens to generate per output sequence
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
before EOS or stop_token_ids can be generated
logprobs: Number of log probabilities to return per output token.
logprobs: Number of log probabilities to return per output token.
Note that the implementation follows the OpenAI API: The return
When set to None, no probability is returned. If set to a non-None
result includes the log probabilities on the `logprobs` most likely
value, the result includes the log probabilities of the specified
tokens, as well the chosen tokens. The API will always return the
number of most likely tokens, as well as the chosen tokens.
log probability of the sampled token, so there may be up to
Note that the implementation follows the OpenAI API: The API will
`logprobs+1` elements in the response.
always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
skip_special_tokens: Whether to skip special tokens in the output.
...
@@ -168,8 +169,8 @@ class SamplingParams:
...
@@ -168,8 +169,8 @@ class SamplingParams:
self
.
ignore_eos
=
ignore_eos
self
.
ignore_eos
=
ignore_eos
self
.
max_tokens
=
max_tokens
self
.
max_tokens
=
max_tokens
self
.
min_tokens
=
min_tokens
self
.
min_tokens
=
min_tokens
self
.
logprobs
=
logprobs
self
.
logprobs
=
1
if
logprobs
is
True
else
logprobs
self
.
prompt_logprobs
=
prompt_logprobs
self
.
prompt_logprobs
=
1
if
prompt_logprobs
is
True
else
prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
# not support returning only a list of token IDs.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment