norm / vllm · Commits

Commit 0bb1e885 (unverified)
Authored Sep 12, 2023 by Antoni Baum; committed via GitHub on Sep 12, 2023
Parent: d6545ad2

Make `max_model_len` configurable (#972)

2 changed files, +22 -1:

  vllm/config.py            +15  -0
  vllm/engine/arg_utils.py   +7  -1
vllm/config.py

```diff
@@ -38,6 +38,8 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        max_model_len: Maximum length of a sequence (including prompt and
+            output). If None, will be derived from the model.
     """

     def __init__(
@@ -50,6 +52,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -63,6 +66,16 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
+        self.max_model_len = None
+        if max_model_len is not None:
+            derived_max_model_len = self.get_max_model_len()
+            if max_model_len > derived_max_model_len:
+                logger.warning(
+                    f"User-specified max_model_len ({max_model_len}) is "
+                    f"greater than the derived max_model_len "
+                    f"({derived_max_model_len}). Make sure the value is "
+                    "correct and within the model context size.")
+            self.max_model_len = max_model_len

     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
@@ -134,6 +147,8 @@ class ModelConfig:
         return total_num_attention_heads // parallel_config.tensor_parallel_size

     def get_max_model_len(self) -> int:
+        if self.max_model_len is not None:
+            return self.max_model_len
         max_model_len = float("inf")
         possible_keys = [
             # OPT
```
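Taken together, these hunks give the resolved context length a simple precedence rule: an explicit `max_model_len` always wins (with a warning if it exceeds what the HF config implies), otherwise `get_max_model_len()` derives the limit from the model config. A minimal, self-contained sketch of that rule — `resolve_max_model_len` is a hypothetical helper for illustration, not part of vLLM:

```python
import logging

logger = logging.getLogger(__name__)

def resolve_max_model_len(user_value, derived_value):
    """Precedence rule from the diff: a user-supplied limit always takes
    effect; it is warned about (not clamped or rejected) when it exceeds
    the limit derived from the model's HF config."""
    if user_value is None:
        return derived_value
    if user_value > derived_value:
        logger.warning(
            "User-specified max_model_len (%d) is greater than the derived "
            "max_model_len (%d). Make sure the value is correct and within "
            "the model context size.", user_value, derived_value)
    return user_value

# Suppose the HF config implies a 2048-token context:
assert resolve_max_model_len(None, 2048) == 2048   # no override: derived value
assert resolve_max_model_len(512, 2048) == 512     # shrinking: silently accepted
assert resolve_max_model_len(4096, 2048) == 4096   # growing: accepted with warning
```

Note that an over-large value is only warned about, not rejected; and since the last hunk makes `get_max_model_len()` return the stored override directly, every downstream caller sees the user's choice.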
vllm/engine/arg_utils.py

```diff
@@ -18,6 +18,7 @@ class EngineArgs:
     load_format: str = 'auto'
     dtype: str = 'auto'
     seed: int = 0
+    max_model_len: Optional[int] = None
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
@@ -89,6 +90,11 @@ class EngineArgs:
                             'The "auto" option will use FP16 precision '
                             'for FP32 and FP16 models, and BF16 precision '
                             'for BF16 models.')
+        parser.add_argument('--max-model-len',
+                            type=int,
+                            default=None,
+                            help='model context length. If unspecified, '
+                            'will be automatically derived from the model.')
         # Parallel arguments
         parser.add_argument('--worker-use-ray',
                             action='store_true',
@@ -153,7 +159,7 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed)
+                                   self.dtype, self.seed, self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
```
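With both files changed, the limit can be set from the command line or in Python. A usage sketch, assuming this commit is applied — the model name and server entrypoint are illustrative; the flag and field names come from the diff above:

```python
# From the CLI, any entrypoint that builds its parser via
# EngineArgs.add_cli_args picks up the new flag, e.g.:
#   python -m vllm.entrypoints.api_server --model facebook/opt-125m \
#       --max-model-len 1024

from vllm.engine.arg_utils import EngineArgs

# Programmatic equivalent: EngineArgs is a dataclass, so the new field can
# be set directly; it is forwarded to ModelConfig when the engine configs
# are created (the ModelConfig(...) call in the last hunk above).
engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=1024)
```

Per the `ModelConfig` change, setting the flag above the model's native context size is permitted but logs a warning rather than failing fast.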