norm / vllm · Commit a19bc5c6
"src/vscode:/vscode.git/clone" did not exist on "4eb8dd838849982d2ca4488de59fb8a2397830e4"
Unverified commit a19bc5c6, authored Sep 27, 2023 by Woosuk Kwon; committed by GitHub, Sep 27, 2023.
Automatically configure `max_num_batched_tokens` (#1198)
Parent: 28e616c4
Showing 2 changed files with 35 additions and 11 deletions (+35 / -11):

vllm/config.py             +34 / -9
vllm/engine/arg_utils.py    +1 / -2
vllm/config.py

@@ -266,11 +266,36 @@ class SchedulerConfig:
         and generated text).
     """
 
-    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_model_len: int) -> None:
-        self.max_num_batched_tokens = max_num_batched_tokens
+    def __init__(
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+    ) -> None:
+        if max_num_batched_tokens is not None:
+            self.max_num_batched_tokens = max_num_batched_tokens
+        else:
+            # If max_model_len is too short, use 2048 as the default value for
+            # higher throughput.
+            self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.max_num_batched_tokens < self.max_model_len:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
+        if self.max_num_batched_tokens < self.max_num_seqs:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+                "be greater than or equal to max_num_seqs "
+                f"({self.max_num_seqs}).")
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
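With this change, `max_num_batched_tokens` may be left unset and `SchedulerConfig` derives a default from the model length, while the new `_verify_args()` rejects inconsistent settings up front. A minimal sketch of the resulting behavior, assuming `SchedulerConfig` is constructed directly as defined in this file (the concrete numbers are illustrative):

from vllm.config import SchedulerConfig

# No explicit max_num_batched_tokens: a short-context model (max_model_len=1024)
# falls back to max(1024, 2048) == 2048 for higher throughput.
short_ctx = SchedulerConfig(max_num_batched_tokens=None,
                            max_num_seqs=256,
                            max_model_len=1024)
assert short_ctx.max_num_batched_tokens == 2048

# A long-context model defaults to max_model_len, so a full-length prompt
# still fits within a single batch.
long_ctx = SchedulerConfig(max_num_batched_tokens=None,
                           max_num_seqs=256,
                           max_model_len=8192)
assert long_ctx.max_num_batched_tokens == 8192

# An explicit value smaller than max_model_len now fails fast in
# _verify_args() instead of silently rejecting long sequences at runtime.
try:
    SchedulerConfig(max_num_batched_tokens=512,
                    max_num_seqs=256,
                    max_model_len=4096)
except ValueError as err:
    print(err)

The 2048 floor keeps throughput reasonable for short-context models, while longer-context models default to a token budget large enough to schedule one full-length sequence.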
@@ -350,14 +375,14 @@ def _get_and_verify_max_len(
         max_len_key = getattr(hf_config, key, None)
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
-    if derived_max_model_len == float("inf"):
-        raise ValueError(
-            "The model's config.json must contain one of the following keys "
-            "to determine the original maximum length of the model: "
-            f"{possible_keys}")
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
+        if derived_max_model_len == float("inf"):
+            raise ValueError(
+                "When using rope_scaling, the model's config.json must "
+                "contain one of the following keys to determine the original "
+                f"maximum length of the model: {possible_keys}")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         derived_max_model_len *= scaling_factor
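With this hunk, a missing length key in `config.json` is only an error when `rope_scaling` is present, because the scaling factor needs an original maximum length to multiply. A small standalone sketch of that arithmetic (a plain dict stands in for the Hugging Face config object; the shortened `possible_keys` list and the values are made up for illustration):

# Illustrative only: a plain dict stands in for the HF config object.
hf_config = {
    "max_position_embeddings": 4096,
    "rope_scaling": {"type": "linear", "factor": 2.0},
}
possible_keys = ["max_position_embeddings", "n_positions"]  # shortened list

derived_max_model_len = float("inf")
for key in possible_keys:
    max_len_key = hf_config.get(key)
    if max_len_key is not None:
        derived_max_model_len = min(derived_max_model_len, max_len_key)

rope_scaling = hf_config.get("rope_scaling")
if rope_scaling is not None:
    if derived_max_model_len == float("inf"):
        raise ValueError("rope_scaling needs an original max length key")
    derived_max_model_len *= rope_scaling["factor"]

print(derived_max_model_len)       # 8192.0: a float after scaling, which is
print(int(derived_max_model_len))  # why the return below is cast with int().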
@@ -371,4 +396,4 @@ def _get_and_verify_max_len(
             " in model's config.json). This may lead to incorrect model "
             "outputs or CUDA errors. Make sure the value is correct and "
             "within the model context size.")
-    return max_model_len
+    return int(max_model_len)
vllm/engine/arg_utils.py

@@ -25,7 +25,7 @@ class EngineArgs:
     block_size: int = 16
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: int = 2560
+    max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
@@ -34,7 +34,6 @@ class EngineArgs:
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
-        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
 
     @staticmethod
     def add_cli_args(
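Because the dataclass default is now `None`, leaving `--max-num-batched-tokens` unset defers the choice to `SchedulerConfig` shown above, and the old `__post_init__` clamp gives way to the `_verify_args()` check. A rough sketch of the new defaults, assuming `EngineArgs` can be instantiated directly (the model name is only an example):

from vllm.engine.arg_utils import EngineArgs

# EngineArgs is a dataclass; constructing it does not load the model.
args = EngineArgs(model="facebook/opt-125m")

print(args.max_num_batched_tokens)  # None (was 2560 before this commit)
print(args.max_num_seqs)            # 256, no longer clamped in __post_init__;
                                    # SchedulerConfig._verify_args() checks it
                                    # against max_num_batched_tokens instead.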