Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
511a6b61
Unverified
Commit
511a6b61
authored
Nov 14, 2025
by
Cyrus Leung
Committed by
GitHub
Nov 14, 2025
Browse files
[Config] Clean up SchedulerConfig initialization (#28665)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
96b23b8e
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
181 additions
and
162 deletions
+181
-162
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+6
-1
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+2
-0
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+1
-0
vllm/config/scheduler.py
vllm/config/scheduler.py
+33
-69
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+135
-73
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+1
-3
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+1
-3
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+1
-3
vllm/utils/__init__.py
vllm/utils/__init__.py
+1
-10
No files found.
tests/models/language/generation/test_hybrid.py
View file @
511a6b61
...
@@ -348,9 +348,14 @@ def test_fp32_cache_state(
...
@@ -348,9 +348,14 @@ def test_fp32_cache_state(
# Helper functions for the APC tests
# Helper functions for the APC tests
def
_get_vllm_runner_params
(
model
,
max_model_len
,
tensor_parallel_size
=
1
):
def
_get_vllm_runner_params
(
model
:
str
,
max_model_len
:
int
,
tensor_parallel_size
:
int
=
1
,
):
return
{
return
{
"model_name"
:
model
,
"model_name"
:
model
,
"enable_chunked_prefill"
:
True
,
"enable_prefix_caching"
:
False
,
"enable_prefix_caching"
:
False
,
"max_model_len"
:
max_model_len
,
"max_model_len"
:
max_model_len
,
"tensor_parallel_size"
:
tensor_parallel_size
,
"tensor_parallel_size"
:
tensor_parallel_size
,
...
...
tests/v1/core/test_scheduler.py
View file @
511a6b61
...
@@ -2256,6 +2256,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
...
@@ -2256,6 +2256,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
enable_chunked_prefill
=
enable_chunked_prefill
,
enable_chunked_prefill
=
enable_chunked_prefill
,
is_encoder_decoder
=
is_encoder_decoder
,
is_encoder_decoder
=
is_encoder_decoder
,
# Must <= max_num_batched_tokens if chunked prefill is disabled
max_model_len
=
SchedulerConfig
.
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
)
)
# `is_encoder_decoder` should only be used during construction
# `is_encoder_decoder` should only be used during construction
...
...
tests/v1/sample/test_logprobs.py
View file @
511a6b61
...
@@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
...
@@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
max_num_batched_tokens
=
16
,
max_num_batched_tokens
=
16
,
max_num_seqs
=
16
,
max_num_seqs
=
16
,
max_model_len
=
128
,
max_model_len
=
128
,
enable_chunked_prefill
=
True
,
enforce_eager
=
True
,
enforce_eager
=
True
,
# TODO: enable this once we support it for
# TODO: enable this once we support it for
# prompt logprobs.
# prompt logprobs.
...
...
vllm/config/scheduler.py
View file @
511a6b61
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
hashlib
import
hashlib
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
dataclasses
import
InitVar
from
dataclasses
import
InitVar
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Literal
,
cast
from
pydantic
import
Field
,
field_validator
,
model_validator
from
pydantic
import
Field
,
field_validator
,
model_validator
from
pydantic.dataclasses
import
dataclass
from
pydantic.dataclasses
import
dataclass
...
@@ -12,11 +12,6 @@ from typing_extensions import Self
...
@@ -12,11 +12,6 @@ from typing_extensions import Self
from
vllm.config.utils
import
config
from
vllm.config.utils
import
config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS
,
)
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -33,25 +28,32 @@ SchedulerPolicy = Literal["fcfs", "priority"]
...
@@ -33,25 +28,32 @@ SchedulerPolicy = Literal["fcfs", "priority"]
class
SchedulerConfig
:
class
SchedulerConfig
:
"""Scheduler configuration."""
"""Scheduler configuration."""
DEFAULT_MAX_NUM_BATCHED_TOKENS
:
ClassVar
[
int
]
=
2048
DEFAULT_MAX_NUM_SEQS
:
ClassVar
[
int
]
=
128
runner_type
:
RunnerType
=
"generate"
runner_type
:
RunnerType
=
"generate"
"""The runner type to launch for the model."""
"""The runner type to launch for the model."""
max_num_batched_tokens
:
int
=
Field
(
default
=
None
,
ge
=
1
)
max_num_batched_tokens
:
int
=
Field
(
default
=
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
ge
=
1
)
"""Maximum number of tokens to be processed in a single iteration.
"""Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
The default value here is mainly for convenience when testing.
be set in `EngineArgs.create_engine_config` based on the usage context."""
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_num_seqs
:
int
=
Field
(
default
=
None
,
ge
=
1
)
max_num_seqs
:
int
=
Field
(
default
=
DEFAULT_MAX_NUM_SEQS
,
ge
=
1
)
"""Maximum number of sequences to be processed in a single iteration.
"""Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
The default value here is mainly for convenience when testing.
be set in `EngineArgs.create_engine_config` based on the usage context."""
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_model_len
:
int
=
Field
(
default
=
8192
,
ge
=
1
)
"""Maximum length of a sequence (including prompt and generated text).
max_model_len
:
int
=
Field
(
default
=
None
,
ge
=
1
)
The default value here is mainly for convenience when testing.
"""Maximum length of a sequence (including prompt and generated text). This
In real usage, this should duplicate `ModelConfig.max_model_len` via
is primarily set in `ModelConfig` and that value should be manually
`EngineArgs`."""
duplicated here."""
max_num_partial_prefills
:
int
=
Field
(
default
=
1
,
ge
=
1
)
max_num_partial_prefills
:
int
=
Field
(
default
=
1
,
ge
=
1
)
"""For chunked prefill, the maximum number of sequences that can be
"""For chunked prefill, the maximum number of sequences that can be
...
@@ -76,9 +78,13 @@ class SchedulerConfig:
...
@@ -76,9 +78,13 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""
present to enable correctness tests until then."""
enable_chunked_prefill
:
bool
=
Field
(
default
=
None
)
enable_chunked_prefill
:
bool
=
True
"""If True, prefill requests can be chunked based
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
on the remaining `max_num_batched_tokens`.
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
is_multimodal_model
:
bool
=
False
is_multimodal_model
:
bool
=
False
"""True if the model is multimodal."""
"""True if the model is multimodal."""
...
@@ -111,9 +117,6 @@ class SchedulerConfig:
...
@@ -111,9 +117,6 @@ class SchedulerConfig:
- "priority" means requests are handled based on given priority (lower
- "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties)."""
value means earlier handling) and time of arrival deciding any ties)."""
chunked_prefill_enabled
:
bool
=
Field
(
init
=
False
)
"""True if chunked prefill is enabled."""
disable_chunked_mm_input
:
bool
=
False
disable_chunked_mm_input
:
bool
=
False
"""If set to true and chunked prefill is enabled, we do not want to
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1
partially schedule a multimodal item. Only used in V1
...
@@ -188,15 +191,7 @@ class SchedulerConfig:
...
@@ -188,15 +191,7 @@ class SchedulerConfig:
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
return
hash_str
return
hash_str
@
field_validator
(
@
field_validator
(
"scheduler_cls"
,
"async_scheduling"
,
mode
=
"wrap"
)
"max_num_batched_tokens"
,
"max_num_seqs"
,
"max_model_len"
,
"enable_chunked_prefill"
,
"scheduler_cls"
,
"async_scheduling"
,
mode
=
"wrap"
,
)
@
classmethod
@
classmethod
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
"""Skip validation if the value is `None` when initialisation is delayed."""
"""Skip validation if the value is `None` when initialisation is delayed."""
...
@@ -205,16 +200,9 @@ class SchedulerConfig:
...
@@ -205,16 +200,9 @@ class SchedulerConfig:
return
handler
(
value
)
return
handler
(
value
)
def
__post_init__
(
self
,
is_encoder_decoder
:
bool
)
->
None
:
def
__post_init__
(
self
,
is_encoder_decoder
:
bool
)
->
None
:
if
self
.
max_model_len
is
None
:
self
.
max_model_len
=
8192
if
self
.
max_num_seqs
is
None
:
self
.
max_num_seqs
=
128
if
is_encoder_decoder
:
if
is_encoder_decoder
:
# Chunked prefill should be disabled for encoder-decoder models.
# Chunked prefill should be disabled for encoder-decoder models.
self
.
disable_chunked_mm_input
=
True
self
.
disable_chunked_mm_input
=
True
self
.
chunked_prefill_enabled
=
False
self
.
enable_chunked_prefill
=
False
self
.
enable_chunked_prefill
=
False
self
.
long_prefill_token_threshold
=
0
self
.
long_prefill_token_threshold
=
0
logger
.
info
(
logger
.
info
(
...
@@ -222,37 +210,6 @@ class SchedulerConfig:
...
@@ -222,37 +210,6 @@ class SchedulerConfig:
" prefix caching; disabling both."
" prefix caching; disabling both."
)
)
if
self
.
max_num_batched_tokens
is
None
:
if
self
.
enable_chunked_prefill
:
self
.
max_num_batched_tokens
=
DEFAULT_MAX_NUM_BATCHED_TOKENS
else
:
# If max_model_len is too short, use
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# for higher throughput.
self
.
max_num_batched_tokens
=
max
(
self
.
max_model_len
,
DEFAULT_MAX_NUM_BATCHED_TOKENS
)
if
self
.
runner_type
==
"pooling"
:
# Choose specific value for higher throughput
self
.
max_num_batched_tokens
=
max
(
self
.
max_num_batched_tokens
,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS
,
)
if
self
.
is_multimodal_model
:
# The value needs to be at least the number of multimodal tokens
self
.
max_num_batched_tokens
=
max
(
self
.
max_num_batched_tokens
,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self
.
max_num_batched_tokens
=
min
(
self
.
max_num_seqs
*
self
.
max_model_len
,
self
.
max_num_batched_tokens
)
self
.
max_num_encoder_input_tokens
=
self
.
max_num_batched_tokens
self
.
max_num_encoder_input_tokens
=
self
.
max_num_batched_tokens
self
.
encoder_cache_size
=
self
.
max_num_batched_tokens
self
.
encoder_cache_size
=
self
.
max_num_batched_tokens
...
@@ -262,7 +219,6 @@ class SchedulerConfig:
...
@@ -262,7 +219,6 @@ class SchedulerConfig:
self
.
max_num_batched_tokens
,
self
.
max_num_batched_tokens
,
)
)
self
.
chunked_prefill_enabled
=
self
.
enable_chunked_prefill
if
self
.
max_num_partial_prefills
>
1
:
if
self
.
max_num_partial_prefills
>
1
:
if
self
.
long_prefill_token_threshold
==
0
:
if
self
.
long_prefill_token_threshold
==
0
:
self
.
long_prefill_token_threshold
=
int
(
self
.
max_model_len
*
0.04
)
self
.
long_prefill_token_threshold
=
int
(
self
.
max_model_len
*
0.04
)
...
@@ -276,6 +232,14 @@ class SchedulerConfig:
...
@@ -276,6 +232,14 @@ class SchedulerConfig:
self
.
long_prefill_token_threshold
,
self
.
long_prefill_token_threshold
,
)
)
@
property
def
chunked_prefill_enabled
(
self
)
->
bool
:
return
self
.
enable_chunked_prefill
@
chunked_prefill_enabled
.
setter
def
chunked_prefill_enabled
(
self
,
value
:
bool
):
self
.
enable_chunked_prefill
=
value
@
model_validator
(
mode
=
"after"
)
@
model_validator
(
mode
=
"after"
)
def
_verify_args
(
self
)
->
Self
:
def
_verify_args
(
self
)
->
Self
:
if
(
if
(
...
...
vllm/engine/arg_utils.py
View file @
511a6b61
...
@@ -428,11 +428,11 @@ class EngineArgs:
...
@@ -428,11 +428,11 @@ class EngineArgs:
cpu_offload_gb
:
float
=
CacheConfig
.
cpu_offload_gb
cpu_offload_gb
:
float
=
CacheConfig
.
cpu_offload_gb
gpu_memory_utilization
:
float
=
CacheConfig
.
gpu_memory_utilization
gpu_memory_utilization
:
float
=
CacheConfig
.
gpu_memory_utilization
kv_cache_memory_bytes
:
int
|
None
=
CacheConfig
.
kv_cache_memory_bytes
kv_cache_memory_bytes
:
int
|
None
=
CacheConfig
.
kv_cache_memory_bytes
max_num_batched_tokens
:
int
|
None
=
SchedulerConfig
.
max_num_batched_tokens
max_num_batched_tokens
:
int
|
None
=
None
max_num_partial_prefills
:
int
=
SchedulerConfig
.
max_num_partial_prefills
max_num_partial_prefills
:
int
=
SchedulerConfig
.
max_num_partial_prefills
max_long_partial_prefills
:
int
=
SchedulerConfig
.
max_long_partial_prefills
max_long_partial_prefills
:
int
=
SchedulerConfig
.
max_long_partial_prefills
long_prefill_token_threshold
:
int
=
SchedulerConfig
.
long_prefill_token_threshold
long_prefill_token_threshold
:
int
=
SchedulerConfig
.
long_prefill_token_threshold
max_num_seqs
:
int
|
None
=
SchedulerConfig
.
max_num_seqs
max_num_seqs
:
int
|
None
=
None
max_logprobs
:
int
=
ModelConfig
.
max_logprobs
max_logprobs
:
int
=
ModelConfig
.
max_logprobs
logprobs_mode
:
LogprobsMode
=
ModelConfig
.
logprobs_mode
logprobs_mode
:
LogprobsMode
=
ModelConfig
.
logprobs_mode
disable_log_stats
:
bool
=
False
disable_log_stats
:
bool
=
False
...
@@ -485,7 +485,7 @@ class EngineArgs:
...
@@ -485,7 +485,7 @@ class EngineArgs:
model_loader_extra_config
:
dict
=
get_field
(
LoadConfig
,
"model_loader_extra_config"
)
model_loader_extra_config
:
dict
=
get_field
(
LoadConfig
,
"model_loader_extra_config"
)
ignore_patterns
:
str
|
list
[
str
]
=
get_field
(
LoadConfig
,
"ignore_patterns"
)
ignore_patterns
:
str
|
list
[
str
]
=
get_field
(
LoadConfig
,
"ignore_patterns"
)
enable_chunked_prefill
:
bool
|
None
=
SchedulerConfig
.
enable_chunked_prefill
enable_chunked_prefill
:
bool
|
None
=
None
disable_chunked_mm_input
:
bool
=
SchedulerConfig
.
disable_chunked_mm_input
disable_chunked_mm_input
:
bool
=
SchedulerConfig
.
disable_chunked_mm_input
disable_hybrid_kv_cache_manager
:
bool
=
(
disable_hybrid_kv_cache_manager
:
bool
=
(
...
@@ -1738,41 +1738,41 @@ class EngineArgs:
...
@@ -1738,41 +1738,41 @@ class EngineArgs:
)
)
_raise_unsupported_error
(
feature_name
=
name
)
_raise_unsupported_error
(
feature_name
=
name
)
def
_set_default_args
(
@
classmethod
self
,
usage_context
:
UsageContext
,
model_config
:
ModelConfig
def
get_chunked_prefill_prefix_caching_defaults
(
)
->
None
:
cls
,
"""Set Default Arguments for V1 Engine."""
model_config
:
ModelConfig
,
)
->
tuple
[
bool
,
bool
]:
# V1 uses chunked prefills and prefix caching by default
# for non-pooling tasks.
# For pooling tasks the default is False
if
model_config
.
runner_type
!=
"pooling"
:
if
model_config
.
runner_type
!=
"pooling"
:
self
.
enable
_chunked_prefill
=
True
default
_chunked_prefill
=
True
if
self
.
enable_prefix_caching
is
None
:
# Disable prefix caching default for hybrid models
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
# since the feature is still experimental.
if
model_config
.
is_hybrid
:
default_prefix_caching
=
not
model_config
.
is_hybrid
self
.
enable_prefix_caching
=
False
else
:
self
.
enable_prefix_caching
=
True
else
:
else
:
assert
model_config
.
pooler_config
is
not
None
pooling_type
=
model_config
.
pooler_config
.
pooling_type
pooling_type
=
model_config
.
pooler_config
.
pooling_type
is_causal
=
getattr
(
model_config
.
hf_config
,
"is_causal"
,
True
)
incremental_prefill_supported
=
(
incremental_prefill_supported
=
(
pooling_type
is
not
None
pooling_type
is
not
None
and
pooling_type
.
lower
()
==
"last"
and
pooling_type
.
lower
()
==
"last"
and
bool
(
is_causal
)
and
getattr
(
model_config
.
hf_config
,
"
is_causal
"
,
True
)
)
)
action
=
"Enabling"
if
incremental_prefill_supported
else
"Disabling"
default_chunked_prefill
=
incremental_prefill_supported
default_prefix_caching
=
incremental_prefill_supported
if
self
.
enable_chunked_prefill
is
None
:
return
default_chunked_prefill
,
default_prefix_caching
self
.
enable_chunked_prefill
=
incremental_prefill_supported
logger
.
info
(
"(%s) chunked prefill by default"
,
action
)
@
classmethod
if
self
.
enable_prefix_caching
is
None
:
def
get_batch_defaults
(
self
.
enable_prefix_caching
=
incremental_prefill_supported
cls
,
logger
.
info
(
"(%s) prefix caching by default"
,
action
)
world_size
:
int
,
)
->
tuple
[
dict
[
UsageContext
|
None
,
int
],
dict
[
UsageContext
|
None
,
int
]]:
from
vllm.usage.usage_lib
import
UsageContext
default_max_num_batched_tokens
:
dict
[
UsageContext
|
None
,
int
]
default_max_num_seqs
:
dict
[
UsageContext
|
None
,
int
]
# When no user override, set the default values based on the usage
# When no user override, set the default values based on the usage
# context.
# context.
...
@@ -1793,8 +1793,6 @@ class EngineArgs:
...
@@ -1793,8 +1793,6 @@ class EngineArgs:
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
# throughput, see PR #17885 for more details.
# So here we do an extra device name check to prevent such regression.
# So here we do an extra device name check to prevent such regression.
from
vllm.usage.usage_lib
import
UsageContext
if
device_memory
>=
70
*
GiB_bytes
and
"a100"
not
in
device_name
:
if
device_memory
>=
70
*
GiB_bytes
and
"a100"
not
in
device_name
:
# For GPUs like H100 and MI300x, use larger default values.
# For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens
=
{
default_max_num_batched_tokens
=
{
...
@@ -1818,22 +1816,26 @@ class EngineArgs:
...
@@ -1818,22 +1816,26 @@ class EngineArgs:
# tpu specific default values.
# tpu specific default values.
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
default_max_num_batched_tokens_tpu
=
{
chip_name
=
current_platform
.
get_device_name
()
UsageContext
.
LLM_CLASS
:
{
"V6E"
:
2048
,
if
chip_name
==
"V6E"
:
"V5E"
:
1024
,
default_max_num_batched_tokens
=
{
"V5P"
:
512
,
UsageContext
.
LLM_CLASS
:
2048
,
},
UsageContext
.
OPENAI_API_SERVER
:
1024
,
UsageContext
.
OPENAI_API_SERVER
:
{
}
"V6E"
:
1024
,
elif
chip_name
==
"V5E"
:
"V5E"
:
512
,
default_max_num_batched_tokens
=
{
"V5P"
:
256
,
UsageContext
.
LLM_CLASS
:
1024
,
},
UsageContext
.
OPENAI_API_SERVER
:
512
,
}
elif
chip_name
==
"V5P"
:
default_max_num_batched_tokens
=
{
UsageContext
.
LLM_CLASS
:
512
,
UsageContext
.
OPENAI_API_SERVER
:
256
,
}
}
# cpu specific default values.
# cpu specific default values.
if
current_platform
.
is_cpu
():
if
current_platform
.
is_cpu
():
world_size
=
self
.
pipeline_parallel_size
*
self
.
tensor_parallel_size
default_max_num_batched_tokens
=
{
default_max_num_batched_tokens
=
{
UsageContext
.
LLM_CLASS
:
4096
*
world_size
,
UsageContext
.
LLM_CLASS
:
4096
*
world_size
,
UsageContext
.
OPENAI_API_SERVER
:
2048
*
world_size
,
UsageContext
.
OPENAI_API_SERVER
:
2048
*
world_size
,
...
@@ -1843,44 +1845,104 @@ class EngineArgs:
...
@@ -1843,44 +1845,104 @@ class EngineArgs:
UsageContext
.
OPENAI_API_SERVER
:
128
*
world_size
,
UsageContext
.
OPENAI_API_SERVER
:
128
*
world_size
,
}
}
use_context_value
=
usage_context
.
value
if
usage_context
else
None
return
default_max_num_batched_tokens
,
default_max_num_seqs
if
(
self
.
max_num_batched_tokens
is
None
def
_set_default_args
(
and
usage_context
in
default_max_num_batched_tokens
self
,
usage_context
:
UsageContext
,
model_config
:
ModelConfig
)
->
None
:
"""Set Default Arguments for V1 Engine."""
(
default_chunked_prefill
,
default_prefix_caching
,
)
=
self
.
get_chunked_prefill_prefix_caching_defaults
(
model_config
)
if
self
.
enable_chunked_prefill
is
None
:
self
.
enable_chunked_prefill
=
default_chunked_prefill
logger
.
debug
(
"%s chunked prefill by default"
,
"Enabling"
if
default_chunked_prefill
else
"Disabling"
,
)
elif
(
model_config
.
runner_type
==
"pooling"
and
self
.
enable_chunked_prefill
and
not
default_chunked_prefill
):
):
if
current_platform
.
is_tpu
():
logger
.
warning
(
chip_name
=
current_platform
.
get_device_name
()
"This model does not officially support chunked prefill. "
if
chip_name
in
default_max_num_batched_tokens_tpu
[
usage_context
]:
"Enabling this manually may cause the engine to crash "
self
.
max_num_batched_tokens
=
default_max_num_batched_tokens_tpu
[
"or produce incorrect outputs."
,
usage_context
)
][
chip_name
]
else
:
if
self
.
enable_prefix_caching
is
None
:
self
.
max_num_batched_tokens
=
default_max_num_batched_tokens
[
self
.
enable_prefix_caching
=
default_prefix_caching
usage_context
]
else
:
if
not
self
.
enable_chunked_prefill
:
self
.
max_num_batched_tokens
=
model_config
.
max_model_len
else
:
self
.
max_num_batched_tokens
=
default_max_num_batched_tokens
[
usage_context
]
logger
.
debug
(
logger
.
debug
(
"Setting max_num_batched_tokens to %d for %s usage context."
,
"%s prefix caching by default"
,
"Enabling"
if
default_prefix_caching
else
"Disabling"
,
)
elif
(
model_config
.
runner_type
==
"pooling"
and
self
.
enable_prefix_caching
and
not
default_prefix_caching
):
logger
.
warning
(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs."
,
)
world_size
=
self
.
pipeline_parallel_size
*
self
.
tensor_parallel_size
(
default_max_num_batched_tokens
,
default_max_num_seqs
,
)
=
self
.
get_batch_defaults
(
world_size
)
orig_max_num_batched_tokens
=
self
.
max_num_batched_tokens
orig_max_num_seqs
=
self
.
max_num_seqs
if
self
.
max_num_batched_tokens
is
None
:
self
.
max_num_batched_tokens
=
default_max_num_batched_tokens
.
get
(
usage_context
,
SchedulerConfig
.
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
)
if
self
.
max_num_seqs
is
None
:
self
.
max_num_seqs
=
default_max_num_seqs
.
get
(
usage_context
,
SchedulerConfig
.
DEFAULT_MAX_NUM_SEQS
,
)
if
orig_max_num_batched_tokens
is
None
:
if
not
self
.
enable_chunked_prefill
:
# If max_model_len is too short, use the default for higher throughput.
self
.
max_num_batched_tokens
=
max
(
model_config
.
max_model_len
,
self
.
max_num_batched_tokens
,
self
.
max_num_batched_tokens
,
use_context_value
,
)
)
if
self
.
max_num_seqs
is
None
and
usage_context
in
default_max_num_seqs
:
# When using default settings,
self
.
max_num_seqs
=
min
(
# Ensure max_num_batched_tokens does not exceed model limit.
default_max_num_seqs
[
usage_context
],
# Some models (e.g., Whisper) have embeddings tied to max length.
self
.
max_num_batched_tokens
or
sys
.
maxsize
,
self
.
max_num_batched_tokens
=
min
(
self
.
max_num_seqs
*
model_config
.
max_model_len
,
self
.
max_num_batched_tokens
,
)
)
logger
.
debug
(
logger
.
debug
(
"Setting max_num_seqs to %d for %s usage context."
,
"Defaulting max_num_batched_tokens to %d for %s usage context."
,
self
.
max_num_batched_tokens
,
usage_context
.
value
if
usage_context
else
None
,
)
if
orig_max_num_seqs
is
None
:
assert
self
.
max_num_batched_tokens
is
not
None
# For type checking
self
.
max_num_seqs
=
min
(
self
.
max_num_seqs
,
self
.
max_num_batched_tokens
)
logger
.
debug
(
"Defaulting max_num_seqs to %d for %s usage context."
,
self
.
max_num_seqs
,
self
.
max_num_seqs
,
use_context
_
value
,
us
ag
e_context
.
value
if
usage_context
else
None
,
)
)
...
...
vllm/platforms/cpu.py
View file @
511a6b61
...
@@ -15,7 +15,6 @@ import torch
...
@@ -15,7 +15,6 @@ import torch
from
vllm
import
envs
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
.interface
import
CpuArchEnum
,
Platform
,
PlatformEnum
from
.interface
import
CpuArchEnum
,
Platform
,
PlatformEnum
...
@@ -339,10 +338,9 @@ class CpuPlatform(Platform):
...
@@ -339,10 +338,9 @@ class CpuPlatform(Platform):
"prefill and prefix caching to be disabled."
"prefill and prefix caching to be disabled."
)
)
vllm_config
.
scheduler_config
.
enable_chunked_prefill
=
False
vllm_config
.
scheduler_config
.
enable_chunked_prefill
=
False
vllm_config
.
scheduler_config
.
chunked_prefill_enabled
=
False
vllm_config
.
scheduler_config
.
max_num_batched_tokens
=
max
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
=
max
(
vllm_config
.
scheduler_config
.
max_model_len
,
vllm_config
.
scheduler_config
.
max_model_len
,
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
vllm_config
.
scheduler_config
.
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
)
)
@
classmethod
@
classmethod
...
...
vllm/platforms/tpu.py
View file @
511a6b61
...
@@ -10,7 +10,6 @@ from tpu_info import device
...
@@ -10,7 +10,6 @@ from tpu_info import device
from
vllm.inputs
import
ProcessorInputs
,
PromptType
from
vllm.inputs
import
ProcessorInputs
,
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
...
@@ -186,10 +185,9 @@ class TpuPlatform(Platform):
...
@@ -186,10 +185,9 @@ class TpuPlatform(Platform):
"prefill and prefix caching to be disabled."
"prefill and prefix caching to be disabled."
)
)
vllm_config
.
scheduler_config
.
enable_chunked_prefill
=
False
vllm_config
.
scheduler_config
.
enable_chunked_prefill
=
False
vllm_config
.
scheduler_config
.
chunked_prefill_enabled
=
False
vllm_config
.
scheduler_config
.
max_num_batched_tokens
=
max
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
=
max
(
vllm_config
.
scheduler_config
.
max_model_len
,
vllm_config
.
scheduler_config
.
max_model_len
,
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
vllm_config
.
scheduler_config
.
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
)
)
@
classmethod
@
classmethod
...
...
vllm/platforms/xpu.py
View file @
511a6b61
...
@@ -9,7 +9,6 @@ import torch
...
@@ -9,7 +9,6 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
...
@@ -185,10 +184,9 @@ class XPUPlatform(Platform):
...
@@ -185,10 +184,9 @@ class XPUPlatform(Platform):
"prefill and prefix caching to be disabled."
"prefill and prefix caching to be disabled."
)
)
vllm_config
.
scheduler_config
.
enable_chunked_prefill
=
False
vllm_config
.
scheduler_config
.
enable_chunked_prefill
=
False
vllm_config
.
scheduler_config
.
chunked_prefill_enabled
=
False
vllm_config
.
scheduler_config
.
max_num_batched_tokens
=
max
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
=
max
(
vllm_config
.
scheduler_config
.
max_model_len
,
vllm_config
.
scheduler_config
.
max_model_len
,
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
vllm_config
.
scheduler_config
.
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
)
)
@
classmethod
@
classmethod
...
...
vllm/utils/__init__.py
View file @
511a6b61
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
uuid
import
uuid
import
warnings
import
warnings
from
typing
import
Any
,
TypeVar
from
typing
import
Any
import
torch
import
torch
...
@@ -39,12 +39,6 @@ def __dir__() -> list[str]:
...
@@ -39,12 +39,6 @@ def __dir__() -> list[str]:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# not optimized for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS
=
2048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
5120
# Constants related to forcing the attention backend selection
# Constants related to forcing the attention backend selection
# String name of register which may be set in order to
# String name of register which may be set in order to
...
@@ -60,9 +54,6 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
...
@@ -60,9 +54,6 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL
:
str
=
"INVALID"
STR_INVALID_VAL
:
str
=
"INVALID"
T
=
TypeVar
(
"T"
)
def
random_uuid
()
->
str
:
def
random_uuid
()
->
str
:
return
str
(
uuid
.
uuid4
().
hex
)
return
str
(
uuid
.
uuid4
().
hex
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment