Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a134ef6f
Unverified
Commit
a134ef6f
authored
Apr 18, 2024
by
Simon Mo
Committed by
GitHub
Apr 19, 2024
Browse files
Support eos_token_id from generation_config.json (#4182)
parent
8a7a3e44
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
3 deletions
+30
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+17
-2
vllm/sampling_params.py
vllm/sampling_params.py
+13
-1
No files found.
vllm/engine/llm_engine.py
View file @
a134ef6f
import
time
from
typing
import
Iterable
,
List
,
Optional
,
Type
,
Union
from
transformers
import
PreTrainedTokenizer
from
transformers
import
GenerationConfig
,
PreTrainedTokenizer
import
vllm
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
LoadConfig
,
...
...
@@ -34,6 +34,17 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
def
_load_generation_config_dict
(
model_config
:
ModelConfig
):
try
:
return
GenerationConfig
.
from_pretrained
(
model_config
.
model
,
revision
=
model_config
.
revision
,
).
to_diff_dict
()
except
OSError
:
# Not found.
return
{}
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
...
...
@@ -124,6 +135,8 @@ class LLMEngine:
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
self
.
seq_counter
=
Counter
()
self
.
generation_config_fields
=
_load_generation_config_dict
(
model_config
)
self
.
model_executor
=
executor_class
(
model_config
=
model_config
,
...
...
@@ -391,6 +404,8 @@ class LLMEngine:
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params
.
eos_token_id
=
seq
.
eos_token_id
sampling_params
.
update_from_generation_config
(
self
.
generation_config_fields
)
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
...
...
vllm/sampling_params.py
View file @
a134ef6f
...
...
@@ -2,7 +2,7 @@
import
copy
from
enum
import
IntEnum
from
functools
import
cached_property
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
from
pydantic
import
Field
...
...
@@ -271,6 +271,18 @@ class SamplingParams:
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
f
"Got
{
self
.
best_of
}
."
)
def
update_from_generation_config
(
self
,
generation_config
:
Dict
[
str
,
Any
])
->
None
:
"""Update if there are non-default values from generation_config"""
# Update eos_token_id for generation
if
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
):
# it can be either int or list of int
if
isinstance
(
eos_ids
,
int
):
eos_ids
=
[
eos_ids
]
original_stop_token_ids
=
set
(
self
.
stop_token_ids
)
original_stop_token_ids
.
update
(
eos_ids
)
self
.
stop_token_ids
=
list
(
original_stop_token_ids
)
@
cached_property
def
sampling_type
(
self
)
->
SamplingType
:
if
self
.
use_beam_search
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment