Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0258b7a9
Unverified
Commit
0258b7a9
authored
Apr 10, 2024
by
Travis Johnson
Committed by
GitHub
Apr 10, 2024
Browse files
[Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876)
Signed-off-by:
Travis Johnson
<
tsjohnso@us.ibm.com
>
parent
b3104b2a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
23 deletions
+112
-23
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+94
-22
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+18
-1
No files found.
tests/samplers/test_sampler.py
View file @
0258b7a9
import
itertools
import
random
import
random
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
create_sampling_params
(
min_tokens
,
def
create_sampling_params
(
min_tokens
,
eos_token_id
=
0
,
eos_token_id
=
0
,
stop_token_ids
=
None
):
*
,
stop_token_ids
:
Optional
[
List
[
str
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
min_tokens
=
min_tokens
,
min_tokens
=
min_tokens
,
max_tokens
=
9999
,
# keep higher than max of min_tokens
max_tokens
=
9999
,
# keep higher than max of min_tokens
stop_token_ids
=
stop_token_ids
,
stop_token_ids
=
stop_token_ids
,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs
=
prompt_logprobs
,
)
)
sampling_params
.
eos_token_id
=
eos_token_id
sampling_params
.
eos_token_id
=
eos_token_id
return
sampling_params
return
sampling_params
...
@@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
expected_penalization
=
[]
expected_penalization
=
[]
sequence_metadata_list
=
[]
sequence_metadata_list
=
[]
while
batch_size
>
0
:
# 20% chance to generate seq group metadata list with all prompts
# 20% chance to generate prompt seq group with single sequence
is_prompt
=
random
.
random
()
<
0.2
is_prompt
=
random
.
random
()
<
0.2
while
batch_size
>
0
:
num_seqs
=
1
if
is_prompt
else
random
.
randint
(
1
,
batch_size
)
num_seqs
=
1
if
is_prompt
else
random
.
randint
(
1
,
batch_size
)
eos_token_id
=
random
.
randint
(
0
,
VOCAB_SIZE
-
1
)
eos_token_id
=
random
.
randint
(
0
,
VOCAB_SIZE
-
1
)
...
@@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
seq_group_penalization
=
[]
seq_group_penalization
=
[]
for
_
in
range
(
num_seqs
):
for
_
in
range
(
num_seqs
):
num_input
=
random
.
randint
(
1
,
100
)
num_input
=
random
.
randint
(
1
,
100
)
num_generated
=
random
.
randint
(
1
,
100
)
if
not
is_prompt
else
0
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
seq_data
[
next
(
seq_id_counter
)]
=
create_sequence_data
(
seq_data
[
next
(
seq_id_counter
)]
=
create_sequence_data
(
num_input
=
num_input
,
num_generated
=
num_generated
)
num_input
=
num_input
,
num_generated
=
num_generated
)
seq_group_penalization
.
append
(
num_generated
<
min_tokens
)
seq_group_penalization
.
append
(
num_generated
<
min_tokens
)
...
@@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
]
]
}
}
prompt_with_penalization_and_prompt_logprobs
=
{
"expected_penalization"
:
[
False
,
False
,
True
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_input
=
3
),
},
sampling_params
=
create_sampling_params
(
1
,
prompt_logprobs
=
3
),
block_tables
=
{},
),
]
}
stop_penalizing_after_min_tokens
=
{
stop_penalizing_after_min_tokens
=
{
"expected_penalization"
:
[
False
],
"expected_penalization"
:
[
False
],
"seq_group_metadata_list"
:
[
"seq_group_metadata_list"
:
[
...
@@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
}
}
stop_token_ids
=
[
42
,
99
,
42
,
0
]
# intentional duplication
stop_token_ids
=
[
42
,
99
,
42
,
0
]
# intentional duplication
simple_combination
=
{
prompt_combination
=
{
"expected_penalization"
:
[
True
,
False
,
False
],
"expected_penalization"
:
[
False
,
True
,
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_2"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_input
=
2
),
},
sampling_params
=
create_sampling_params
(
1
,
prompt_logprobs
=
3
),
block_tables
=
{},
),
SequenceGroupMetadata
(
request_id
=
"test_3"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
},
sampling_params
=
create_sampling_params
(
0
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
)
]
}
stop_token_ids
=
[
1
,
999
,
37
,
37
]
# intentional duplication
decode_combination
=
{
"expected_penalization"
:
[
True
,
False
,
False
,
True
,
False
],
"seq_group_metadata_list"
:
[
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
"test_1"
,
request_id
=
"test_1"
,
...
@@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
),
),
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
"test_2"
,
request_id
=
"test_2"
,
is_prompt
=
Tru
e
,
is_prompt
=
Fals
e
,
seq_data
=
{
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
20
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
1
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
10
),
},
},
sampling_params
=
create_sampling_params
(
sampling_params
=
create_sampling_params
(
0
,
stop_token_ids
=
stop_token_ids
),
10
,
prompt_logprobs
=
5
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
block_tables
=
{},
)
)
,
]
]
}
}
...
@@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
test_cases
=
[
test_cases
=
[
prompt_without_penalization
,
prompt_without_penalization
,
prompt_with_penalization
,
prompt_with_penalization
,
prompt_with_penalization_and_prompt_logprobs
,
stop_penalizing_after_min_tokens
,
stop_penalizing_after_min_tokens
,
simple_combination
,
prompt_combination
,
decode_combination
,
]
]
else
:
else
:
test_cases
=
[
generate_test_case
()]
test_cases
=
[
generate_test_case
()]
...
@@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
run_test_case
(
*
,
def
run_test_case
(
*
,
expected_penalization
=
None
,
expected_penalization
=
None
,
seq_group_metadata_list
=
None
):
seq_group_metadata_list
=
None
):
assert
expected_penalization
,
"Invalid test case"
assert
expected_penalization
,
\
assert
seq_group_metadata_list
,
"Invalid test case"
"Invalid test case, need expected_penalization"
assert
seq_group_metadata_list
,
\
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
batch_size
=
0
prompt_lens
=
[]
prompt_lens
=
[]
sampling_params_per_
seq
=
[]
sampling_params_per_
row
=
[]
for
sgm
in
seq_group_metadata_list
:
for
sgm
in
seq_group_metadata_list
:
num_seqs
=
len
(
sgm
.
seq_data
)
batch_size
+=
num_seqs
sampling_params
=
sgm
.
sampling_params
sampling_params
=
sgm
.
sampling_params
for
seq_id
in
sgm
.
seq_data
:
prompt_lens
.
append
(
sgm
.
seq_data
[
seq_id
].
get_prompt_len
())
num_rows
=
len
(
sgm
.
seq_data
)
sampling_params_per_seq
.
append
(
sampling_params
)
if
sgm
.
is_prompt
:
# a prompt seq_group has only one sequence
seq_data
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
prompt_len
=
seq_data
.
get_prompt_len
()
prompt_lens
.
append
(
prompt_len
)
if
sgm
.
sampling_params
.
prompt_logprobs
:
# with prompt_logprobs each token in the prompt has a row in
# logits
num_rows
=
prompt_len
batch_size
+=
num_rows
sampling_params_per_row
.
extend
(
itertools
.
repeat
(
sampling_params
,
num_rows
))
assert
len
(
expected_penalization
)
==
batch_size
,
\
(
"Invalid test case, expected_penalization does not match computed"
"batch size"
)
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
seq_group_metadata_list
,
prompt_lens
=
prompt_lens
,
prompt_lens
=
prompt_lens
if
prompt_lens
else
None
,
subquery_lens
=
prompt_lens
)
subquery_lens
=
prompt_lens
if
prompt_lens
else
None
)
# the logits tensor is modified in-place by the sampler
# the logits tensor is modified in-place by the sampler
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
for
logits_idx
,
(
should_penalize
,
sampling_params
)
in
enumerate
(
for
logits_idx
,
(
should_penalize
,
sampling_params
)
in
enumerate
(
zip
(
expected_penalization
,
sampling_params_per_
seq
)):
zip
(
expected_penalization
,
sampling_params_per_
row
)):
tokens_to_check
=
[
sampling_params
.
eos_token_id
]
tokens_to_check
=
[
sampling_params
.
eos_token_id
]
if
sampling_params
.
stop_token_ids
:
if
sampling_params
.
stop_token_ids
:
...
...
vllm/model_executor/layers/sampler.py
View file @
0258b7a9
...
@@ -27,6 +27,12 @@ class Sampler(nn.Module):
...
@@ -27,6 +27,12 @@ class Sampler(nn.Module):
6. Sample the next tokens.
6. Sample the next tokens.
Here, each sequence group within the batch can have different sampling
Here, each sequence group within the batch can have different sampling
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
The structure of the logits tensor is coupled with the seq_groups in
sampling_metadata. Typically, each sequence in each seq_group has one row in
logits for the next token to be sampled; however, for a seq_group with a
prompt request with the prompt_logprobs sampling parameter, there are rows
in logits for each token in the input prompt.
"""
"""
def
forward
(
def
forward
(
...
@@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
...
@@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
# list of indices in logits that will be set to -inf
# list of indices in logits that will be set to -inf
logits_to_penalize
=
[]
logits_to_penalize
=
[]
start_idx
=
0
start_idx
=
0
for
seq_ids
,
sampling_params
in
sampling_metadata
.
seq_groups
:
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
assert
len
(
seq_ids
)
==
1
start_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
min_tokens
=
sampling_params
.
min_tokens
min_tokens
=
sampling_params
.
min_tokens
if
min_tokens
>
0
:
if
min_tokens
>
0
:
seqs_to_penalize
=
[]
seqs_to_penalize
=
[]
...
@@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
...
@@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits
[
tuple
(
zip
(
*
logits_to_penalize
))]
=
-
float
(
"inf"
)
logits
[
tuple
(
zip
(
*
logits_to_penalize
))]
=
-
float
(
"inf"
)
# verifies that no rows in logits were missed unexpectedly
assert
start_idx
==
logits
.
shape
[
0
]
return
logits
return
logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment