Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a134ef6f
Unverified
Commit
a134ef6f
authored
Apr 18, 2024
by
Simon Mo
Committed by
GitHub
Apr 19, 2024
Browse files
Support eos_token_id from generation_config.json (#4182)
parent
8a7a3e44
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
3 deletions
+30
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+17
-2
vllm/sampling_params.py
vllm/sampling_params.py
+13
-1
No files found.
vllm/engine/llm_engine.py
View file @
a134ef6f
import
time
import
time
from
typing
import
Iterable
,
List
,
Optional
,
Type
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Type
,
Union
from
transformers
import
PreTrainedTokenizer
from
transformers
import
GenerationConfig
,
PreTrainedTokenizer
import
vllm
import
vllm
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
LoadConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
LoadConfig
,
...
@@ -34,6 +34,17 @@ logger = init_logger(__name__)
...
@@ -34,6 +34,17 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
_LOCAL_LOGGING_INTERVAL_SEC
=
5
def _load_generation_config_dict(model_config: ModelConfig):
    """Return the model's generation-config overrides as a plain dict.

    The config is looked up with ``GenerationConfig.from_pretrained`` using
    ``model_config.model`` and ``model_config.revision`` and then flattened
    with ``to_diff_dict()``. An empty dict is returned when the lookup
    raises ``OSError`` (i.e. no generation config could be found).
    """
    try:
        loaded_config = GenerationConfig.from_pretrained(
            model_config.model,
            revision=model_config.revision,
        )
        config_fields = loaded_config.to_diff_dict()
    except OSError:
        # Not found: behave as if the model ships no overrides.
        config_fields = {}
    return config_fields
class
LLMEngine
:
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
"""An LLM engine that receives requests and generates texts.
...
@@ -124,6 +135,8 @@ class LLMEngine:
...
@@ -124,6 +135,8 @@ class LLMEngine:
self
.
_init_tokenizer
()
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
self
.
seq_counter
=
Counter
()
self
.
seq_counter
=
Counter
()
self
.
generation_config_fields
=
_load_generation_config_dict
(
model_config
)
self
.
model_executor
=
executor_class
(
self
.
model_executor
=
executor_class
(
model_config
=
model_config
,
model_config
=
model_config
,
...
@@ -391,6 +404,8 @@ class LLMEngine:
...
@@ -391,6 +404,8 @@ class LLMEngine:
# inject the eos token id into the sampling_params to support min_tokens
# inject the eos token id into the sampling_params to support min_tokens
# processing
# processing
sampling_params
.
eos_token_id
=
seq
.
eos_token_id
sampling_params
.
eos_token_id
=
seq
.
eos_token_id
sampling_params
.
update_from_generation_config
(
self
.
generation_config_fields
)
# Create the sequence group.
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
...
@@ -435,7 +450,7 @@ class LLMEngine:
...
@@ -435,7 +450,7 @@ class LLMEngine:
scheduled_seq_groups
:
List
[
SequenceGroup
],
scheduled_seq_groups
:
List
[
SequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
])
->
List
[
RequestOutput
]:
ignored_seq_groups
:
List
[
SequenceGroup
])
->
List
[
RequestOutput
]:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
Returns RequestOutputs that can be returned to the client.
"""
"""
...
...
vllm/sampling_params.py
View file @
a134ef6f
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
copy
import
copy
from
enum
import
IntEnum
from
enum
import
IntEnum
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
from
pydantic
import
Field
from
pydantic
import
Field
...
@@ -271,6 +271,18 @@ class SamplingParams:
...
@@ -271,6 +271,18 @@ class SamplingParams:
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
f
"Got
{
self
.
best_of
}
."
)
f
"Got
{
self
.
best_of
}
."
)
def update_from_generation_config(
        self, generation_config: Dict[str, Any]) -> None:
    """Merge EOS token ids from ``generation_config`` into the stop tokens.

    Args:
        generation_config: Mapping of generation-config fields. Only the
            ``"eos_token_id"`` entry is read; it may be a single int or a
            list of ints.

    The ids are appended to ``self.stop_token_ids``. Existing entries keep
    their position and duplicates are dropped. Nothing is changed when the
    key is absent, ``None``, or an empty list.
    """
    eos_ids = generation_config.get("eos_token_id")
    # Compare against None rather than relying on truthiness: token id 0 is
    # a legitimate EOS id and must not be skipped.
    if eos_ids is None:
        return
    # It can be either an int or a list of ints.
    if isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    if not eos_ids:
        return
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # caller's original stop-token ordering stays stable.
    self.stop_token_ids = list(
        dict.fromkeys([*self.stop_token_ids, *eos_ids]))
@
cached_property
@
cached_property
def
sampling_type
(
self
)
->
SamplingType
:
def
sampling_type
(
self
)
->
SamplingType
:
if
self
.
use_beam_search
:
if
self
.
use_beam_search
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment