Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0258b7a9
Unverified
Commit
0258b7a9
authored
Apr 10, 2024
by
Travis Johnson
Committed by
GitHub
Apr 10, 2024
Browse files
[Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876)
Signed-off-by:
Travis Johnson
<
tsjohnso@us.ibm.com
>
parent
b3104b2a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
23 deletions
+112
-23
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+94
-22
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+18
-1
No files found.
tests/samplers/test_sampler.py
View file @
0258b7a9
import
itertools
import
random
from
typing
import
List
,
Optional
,
Tuple
from
unittest.mock
import
patch
...
...
@@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
create_sampling_params
(
min_tokens
,
eos_token_id
=
0
,
stop_token_ids
=
None
):
*
,
stop_token_ids
:
Optional
[
List
[
str
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
min_tokens
=
min_tokens
,
max_tokens
=
9999
,
# keep higher than max of min_tokens
stop_token_ids
=
stop_token_ids
,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs
=
prompt_logprobs
,
)
sampling_params
.
eos_token_id
=
eos_token_id
return
sampling_params
...
...
@@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
expected_penalization
=
[]
sequence_metadata_list
=
[]
# 20% chance to generate seq group metadata list with all prompts
is_prompt
=
random
.
random
()
<
0.2
while
batch_size
>
0
:
# 20% chance to generate prompt seq group with single sequence
is_prompt
=
random
.
random
()
<
0.2
num_seqs
=
1
if
is_prompt
else
random
.
randint
(
1
,
batch_size
)
eos_token_id
=
random
.
randint
(
0
,
VOCAB_SIZE
-
1
)
...
...
@@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
seq_group_penalization
=
[]
for
_
in
range
(
num_seqs
):
num_input
=
random
.
randint
(
1
,
100
)
num_generated
=
random
.
randint
(
1
,
100
)
if
not
is_prompt
else
0
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
seq_data
[
next
(
seq_id_counter
)]
=
create_sequence_data
(
num_input
=
num_input
,
num_generated
=
num_generated
)
seq_group_penalization
.
append
(
num_generated
<
min_tokens
)
...
...
@@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
]
}
prompt_with_penalization_and_prompt_logprobs
=
{
"expected_penalization"
:
[
False
,
False
,
True
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_input
=
3
),
},
sampling_params
=
create_sampling_params
(
1
,
prompt_logprobs
=
3
),
block_tables
=
{},
),
]
}
stop_penalizing_after_min_tokens
=
{
"expected_penalization"
:
[
False
],
"seq_group_metadata_list"
:
[
...
...
@@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
}
stop_token_ids
=
[
42
,
99
,
42
,
0
]
# intentional duplication
simple_combination
=
{
"expected_penalization"
:
[
True
,
False
,
False
],
prompt_combination
=
{
"expected_penalization"
:
[
False
,
True
,
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_2"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_input
=
2
),
},
sampling_params
=
create_sampling_params
(
1
,
prompt_logprobs
=
3
),
block_tables
=
{},
),
SequenceGroupMetadata
(
request_id
=
"test_3"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
},
sampling_params
=
create_sampling_params
(
0
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
)
]
}
stop_token_ids
=
[
1
,
999
,
37
,
37
]
# intentional duplication
decode_combination
=
{
"expected_penalization"
:
[
True
,
False
,
False
,
True
,
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
...
...
@@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
),
SequenceGroupMetadata
(
request_id
=
"test_2"
,
is_prompt
=
Tru
e
,
is_prompt
=
Fals
e
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
20
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
1
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
10
),
},
sampling_params
=
create_sampling_params
(
0
,
stop_token_ids
=
stop_token_ids
),
10
,
prompt_logprobs
=
5
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
)
)
,
]
}
...
...
@@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
test_cases
=
[
prompt_without_penalization
,
prompt_with_penalization
,
prompt_with_penalization_and_prompt_logprobs
,
stop_penalizing_after_min_tokens
,
simple_combination
,
prompt_combination
,
decode_combination
,
]
else
:
test_cases
=
[
generate_test_case
()]
...
...
@@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
run_test_case
(
*
,
expected_penalization
=
None
,
seq_group_metadata_list
=
None
):
assert
expected_penalization
,
"Invalid test case"
assert
seq_group_metadata_list
,
"Invalid test case"
assert
expected_penalization
,
\
"Invalid test case, need expected_penalization"
assert
seq_group_metadata_list
,
\
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
prompt_lens
=
[]
sampling_params_per_
seq
=
[]
sampling_params_per_
row
=
[]
for
sgm
in
seq_group_metadata_list
:
num_seqs
=
len
(
sgm
.
seq_data
)
batch_size
+=
num_seqs
sampling_params
=
sgm
.
sampling_params
for
seq_id
in
sgm
.
seq_data
:
prompt_lens
.
append
(
sgm
.
seq_data
[
seq_id
].
get_prompt_len
())
sampling_params_per_seq
.
append
(
sampling_params
)
num_rows
=
len
(
sgm
.
seq_data
)
if
sgm
.
is_prompt
:
# a prompt seq_group has only one sequence
seq_data
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
prompt_len
=
seq_data
.
get_prompt_len
()
prompt_lens
.
append
(
prompt_len
)
if
sgm
.
sampling_params
.
prompt_logprobs
:
# with prompt_logprobs each token in the prompt has a row in
# logits
num_rows
=
prompt_len
batch_size
+=
num_rows
sampling_params_per_row
.
extend
(
itertools
.
repeat
(
sampling_params
,
num_rows
))
assert
len
(
expected_penalization
)
==
batch_size
,
\
(
"Invalid test case, expected_penalization does not match computed"
"batch size"
)
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
=
prompt_lens
,
subquery_lens
=
prompt_lens
)
prompt_lens
=
prompt_lens
if
prompt_lens
else
None
,
subquery_lens
=
prompt_lens
if
prompt_lens
else
None
)
# the logits tensor is modified in-place by the sampler
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
for
logits_idx
,
(
should_penalize
,
sampling_params
)
in
enumerate
(
zip
(
expected_penalization
,
sampling_params_per_
seq
)):
zip
(
expected_penalization
,
sampling_params_per_
row
)):
tokens_to_check
=
[
sampling_params
.
eos_token_id
]
if
sampling_params
.
stop_token_ids
:
...
...
vllm/model_executor/layers/sampler.py
View file @
0258b7a9
...
...
@@ -27,6 +27,12 @@ class Sampler(nn.Module):
6. Sample the next tokens.
Here, each sequence group within the batch can have different sampling
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
The structure of the logits tensor is coupled with the seq_groups in
sampling_metadata. Typically, each sequence in each seq_group has one row in
logits for the next token to be sampled; however, for a seq_group with a
prompt request with the prompt_logprobs sampling parameter, there are rows
in logits for each token in the input prompt.
"""
def
forward
(
...
...
@@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
# list of indices in logits that will be set to -inf
logits_to_penalize
=
[]
start_idx
=
0
for
seq_ids
,
sampling_params
in
sampling_metadata
.
seq_groups
:
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
assert
len
(
seq_ids
)
==
1
start_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
min_tokens
=
sampling_params
.
min_tokens
if
min_tokens
>
0
:
seqs_to_penalize
=
[]
...
...
@@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits
[
tuple
(
zip
(
*
logits_to_penalize
))]
=
-
float
(
"inf"
)
# verifies that no rows in logits were missed unexpectedly
assert
start_idx
==
logits
.
shape
[
0
]
return
logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment