Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
365791ff
Unverified
Commit
365791ff
authored
Jun 27, 2024
by
Nick Hill
Committed by
GitHub
Jun 27, 2024
Browse files
[BugFix] Fix `min_tokens` behaviour for multiple eos tokens (#5849)
parent
691e29ec
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
13 deletions
+23
-13
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-5
vllm/sampling_params.py
vllm/sampling_params.py
+21
-8
No files found.
vllm/engine/llm_engine.py
View file @
365791ff
...
...
@@ -606,12 +606,9 @@ class LLMEngine:
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params
=
sampling_params
.
clone
()
# Add the eos token id into the sampling_params to support min_tokens
# processing
if
seq
.
eos_token_id
is
not
None
:
sampling_params
.
all_stop_token_ids
.
add
(
seq
.
eos_token_id
)
sampling_params
.
update_from_generation_config
(
self
.
generation_config_fields
)
self
.
generation_config_fields
,
seq
.
eos_token_id
)
# Create the sequence group.
seq_group
=
SequenceGroup
(
...
...
vllm/sampling_params.py
View file @
365791ff
...
...
@@ -280,17 +280,30 @@ class SamplingParams:
f
"Got
{
self
.
best_of
}
."
)
def
update_from_generation_config
(
self
,
generation_config
:
Dict
[
str
,
Any
])
->
None
:
self
,
generation_config
:
Dict
[
str
,
Any
],
model_eos_token_id
:
Optional
[
int
]
=
None
)
->
None
:
"""Update if there are non-default values from generation_config"""
if
model_eos_token_id
is
not
None
:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self
.
all_stop_token_ids
.
add
(
model_eos_token_id
)
# Update eos_token_id for generation
if
(
not
self
.
ignore_eos
)
and
(
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
)):
if
(
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
))
is
not
None
:
# it can be either int or list of int
if
isinstance
(
eos_ids
,
int
):
eos_ids
=
[
eos_ids
]
original_stop_token_ids
=
set
(
self
.
stop_token_ids
)
original_stop_token_ids
.
update
(
eos_ids
)
self
.
stop_token_ids
=
list
(
original_stop_token_ids
)
eos_ids
=
{
eos_ids
}
if
isinstance
(
eos_ids
,
int
)
else
set
(
eos_ids
)
if
model_eos_token_id
is
not
None
:
# We don't need to include the primary eos_token_id in
# stop_token_ids since it's handled separately for stopping
# purposes.
eos_ids
.
discard
(
model_eos_token_id
)
if
eos_ids
:
self
.
all_stop_token_ids
.
update
(
eos_ids
)
if
not
self
.
ignore_eos
:
eos_ids
.
update
(
self
.
stop_token_ids
)
self
.
stop_token_ids
=
list
(
eos_ids
)
@
cached_property
def
sampling_type
(
self
)
->
SamplingType
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment