Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4b2ed792
Unverified
Commit
4b2ed792
authored
May 09, 2025
by
Harry Mellor
Committed by
GitHub
May 09, 2025
Browse files
Improve configs - the rest! (#17562)
Signed-off-by:
Harry Mellor
<
19981378+hmellor@users.noreply.github.com
>
parent
7e357113
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
461 additions
and
345 deletions
+461
-345
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+1
-4
tests/compile/test_functionalization.py
tests/compile/test_functionalization.py
+3
-4
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+3
-3
tests/compile/test_sequence_parallelism.py
tests/compile/test_sequence_parallelism.py
+3
-4
tests/compile/test_silu_mul_quant_fusion.py
tests/compile/test_silu_mul_quant_fusion.py
+2
-3
tests/distributed/test_sequence_parallel.py
tests/distributed/test_sequence_parallel.py
+2
-2
tests/engine/test_arg_utils.py
tests/engine/test_arg_utils.py
+61
-15
vllm/compilation/vllm_inductor_pass.py
vllm/compilation/vllm_inductor_pass.py
+2
-5
vllm/config.py
vllm/config.py
+303
-220
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+2
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+56
-75
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+9
-4
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+6
-5
vllm/utils.py
vllm/utils.py
+8
-0
No files found.
tests/compile/test_full_graph.py
View file @
4b2ed792
...
...
@@ -9,7 +9,7 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
PassConfig
from
vllm.platforms
import
current_platform
from
..utils
import
create_new_process_for_each_test
...
...
@@ -95,9 +95,6 @@ def test_full_graph(
run_model
(
optimization_level
,
model
,
model_kwargs
)
PassConfig
=
CompilationConfig
.
PassConfig
# TODO(luka) add other supported compilation config scenarios here
@
pytest
.
mark
.
parametrize
(
"compilation_config, model_info"
,
...
...
tests/compile/test_functionalization.py
View file @
4b2ed792
...
...
@@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
)
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
,
is_func
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.config
import
CompilationConfig
,
VllmConfig
from
vllm.config
import
CompilationConfig
,
PassConfig
,
VllmConfig
from
.backend
import
TestBackend
...
...
@@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch
.
set_default_device
(
"cuda"
)
vllm_config
=
VllmConfig
()
vllm_config
.
compilation_config
=
CompilationConfig
(
pass_config
=
\
CompilationConfig
.
PassConfig
(
enable_fusion
=
do_fusion
,
enable_noop
=
True
))
vllm_config
.
compilation_config
=
CompilationConfig
(
pass_config
=
PassConfig
(
enable_fusion
=
do_fusion
,
enable_noop
=
True
))
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
fusion_pass
=
FusionPass
.
instance
(
vllm_config
)
act_quant_fusion_pass
=
ActivationQuantFusionPass
(
vllm_config
)
...
...
tests/compile/test_fusion.py
View file @
4b2ed792
...
...
@@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass
,
QuantKey
)
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
PassConfig
,
VllmConfig
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_FP8_SUPPORTED
,
Fp8LinearOp
,
maybe_create_device_identity
)
...
...
@@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
custom_ops
=
[
"+rms_norm"
]))
vllm_config
.
compilation_config
.
pass_config
=
\
CompilationConfig
.
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
)
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
)
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
):
# Reshape pass is needed for the fusion pass to work
noop_pass
=
NoOpEliminationPass
(
vllm_config
)
...
...
tests/compile/test_sequence_parallelism.py
View file @
4b2ed792
...
...
@@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn_maybe
,
is_func
)
from
vllm.compilation.sequence_parallelism
import
SequenceParallelismPass
from
vllm.config
import
(
CompilationConfig
,
DeviceConfig
,
ModelConfig
,
VllmConfig
)
PassConfig
,
VllmConfig
)
from
vllm.distributed
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
)
...
...
@@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
vllm_config
=
VllmConfig
()
vllm_config
.
compilation_config
=
CompilationConfig
(
pass_config
=
CompilationConfig
.
PassConfig
(
enable_sequence_parallelism
=
True
,
),
)
vllm_config
.
compilation_config
=
CompilationConfig
(
pass_config
=
PassConfig
(
enable_sequence_parallelism
=
True
))
vllm_config
.
device_config
=
DeviceConfig
(
device
=
torch
.
device
(
"cuda"
))
# this is a fake model name to construct the model config
...
...
tests/compile/test_silu_mul_quant_fusion.py
View file @
4b2ed792
...
...
@@ -6,7 +6,7 @@ import vllm.envs as envs
from
vllm._custom_ops
import
scaled_fp8_quant
from
vllm.compilation.activation_quant_fusion
import
ActivationQuantFusionPass
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
from
vllm.config
import
CompilationConfig
,
VllmConfig
from
vllm.config
import
CompilationConfig
,
PassConfig
,
VllmConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
.backend
import
TestBackend
...
...
@@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
# Reshape pass is needed for the fusion pass to work
config
=
VllmConfig
()
config
.
compilation_config
=
CompilationConfig
(
pass_config
=
CompilationConfig
.
PassConfig
(
enable_fusion
=
True
,
enable_reshape
=
True
))
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_reshape
=
True
))
fusion_pass
=
ActivationQuantFusionPass
(
config
)
backend
=
TestBackend
(
fusion_pass
)
...
...
tests/distributed/test_sequence_parallel.py
View file @
4b2ed792
...
...
@@ -206,7 +206,7 @@ def _compare_sp(
'compile_sizes'
:
[
4
,
8
],
'splitting_ops'
:
[],
'pass_config'
:
{
'enable_sequence_parallism'
:
sp_enabled
,
'enable_sequence_parall
el
ism'
:
sp_enabled
,
'enable_noop'
:
True
,
'enable_fusion'
:
True
,
},
...
...
@@ -223,7 +223,7 @@ def _compare_sp(
"--distributed-executor-backend"
,
distributed_backend
,
"--compilation_config"
,
str
(
compilation_config
),
json
.
dumps
(
compilation_config
),
]
tp_env
=
{
...
...
tests/engine/test_arg_utils.py
View file @
4b2ed792
...
...
@@ -8,21 +8,18 @@ from typing import Literal, Optional
import
pytest
from
vllm.config
import
config
from
vllm.config
import
CompilationConfig
,
config
from
vllm.engine.arg_utils
import
(
EngineArgs
,
contains_type
,
get_kwargs
,
get_type
,
is_not_builtin
,
is_type
,
literal_to_kwargs
,
nullable_kvs
,
optional_type
)
optional_type
,
parse_type
)
from
vllm.utils
import
FlexibleArgumentParser
@
pytest
.
mark
.
parametrize
((
"type"
,
"value"
,
"expected"
),
[
(
int
,
"42"
,
42
),
(
int
,
"None"
,
None
),
(
float
,
"3.14"
,
3.14
),
(
float
,
"None"
,
None
),
(
str
,
"Hello World!"
,
"Hello World!"
),
(
str
,
"None"
,
None
),
(
json
.
loads
,
'{"foo":1,"bar":2}'
,
{
"foo"
:
1
,
"bar"
:
2
...
...
@@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
"foo"
:
1
,
"bar"
:
2
}),
(
json
.
loads
,
"None"
,
None
),
])
def
test_
optional
_type
(
type
,
value
,
expected
):
optional
_type_func
=
optional
_type
(
type
)
def
test_
parse
_type
(
type
,
value
,
expected
):
parse
_type_func
=
parse
_type
(
type
)
context
=
nullcontext
()
if
value
==
"foo=1,bar=2"
:
context
=
pytest
.
warns
(
DeprecationWarning
)
with
context
:
assert
optional_type_func
(
value
)
==
expected
assert
parse_type_func
(
value
)
==
expected
def
test_optional_type
():
optional_type_func
=
optional_type
(
int
)
assert
optional_type_func
(
"None"
)
is
None
assert
optional_type_func
(
"42"
)
==
42
@
pytest
.
mark
.
parametrize
((
"type_hint"
,
"type"
,
"expected"
),
[
...
...
@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
@
config
@
dataclass
class
DummyConfigClass
:
class
NestedConfig
:
field
:
int
=
1
"""field"""
@
config
@
dataclass
class
FromCliConfig1
:
field
:
int
=
1
"""field"""
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
):
inst
=
cls
(
**
json
.
loads
(
cli_value
))
inst
.
field
+=
1
return
inst
@
config
@
dataclass
class
FromCliConfig2
:
field
:
int
=
1
"""field"""
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
):
inst
=
cls
(
**
json
.
loads
(
cli_value
))
inst
.
field
+=
2
return
inst
@
config
@
dataclass
class
DummyConfig
:
regular_bool
:
bool
=
True
"""Regular bool with default True"""
optional_bool
:
Optional
[
bool
]
=
None
...
...
@@ -108,18 +143,24 @@ class DummyConfigClass:
"""Literal of literals with default 1"""
json_tip
:
dict
=
field
(
default_factory
=
dict
)
"""Dict which will be JSON in CLI"""
nested_config
:
NestedConfig
=
field
(
default_factory
=
NestedConfig
)
"""Nested config"""
from_cli_config1
:
FromCliConfig1
=
field
(
default_factory
=
FromCliConfig1
)
"""Config with from_cli method"""
from_cli_config2
:
FromCliConfig2
=
field
(
default_factory
=
FromCliConfig2
)
"""Different config with from_cli method"""
@
pytest
.
mark
.
parametrize
((
"type_hint"
,
"expected"
),
[
(
int
,
False
),
(
DummyConfig
Class
,
True
),
(
DummyConfig
,
True
),
])
def
test_is_not_builtin
(
type_hint
,
expected
):
assert
is_not_builtin
(
type_hint
)
==
expected
def
test_get_kwargs
():
kwargs
=
get_kwargs
(
DummyConfig
Class
)
kwargs
=
get_kwargs
(
DummyConfig
)
print
(
kwargs
)
# bools should not have their type set
...
...
@@ -142,6 +183,11 @@ def test_get_kwargs():
# dict should have json tip in help
json_tip
=
"
\n\n
Should be a valid JSON string."
assert
kwargs
[
"json_tip"
][
"help"
].
endswith
(
json_tip
)
# nested config should construct the nested config
assert
kwargs
[
"nested_config"
][
"type"
](
'{"field": 2}'
)
==
NestedConfig
(
2
)
# from_cli configs should be constructed with the correct method
assert
kwargs
[
"from_cli_config1"
][
"type"
](
'{"field": 2}'
).
field
==
3
assert
kwargs
[
"from_cli_config2"
][
"type"
](
'{"field": 2}'
).
field
==
4
@
pytest
.
mark
.
parametrize
((
"arg"
,
"expected"
),
[
...
...
@@ -177,7 +223,7 @@ def test_compilation_config():
# default value
args
=
parser
.
parse_args
([])
assert
args
.
compilation_config
is
None
assert
args
.
compilation_config
==
CompilationConfig
()
# set to O3
args
=
parser
.
parse_args
([
"-O3"
])
...
...
@@ -194,7 +240,7 @@ def test_compilation_config():
# set to string form of a dict
args
=
parser
.
parse_args
([
"--compilation-config"
,
"{'
level
'
: 3,
'
cudagraph_capture_sizes
'
: [1, 2, 4, 8]}
"
,
'{"
level
"
: 3,
"
cudagraph_capture_sizes
"
: [1, 2, 4, 8]}
'
,
])
assert
(
args
.
compilation_config
.
level
==
3
and
args
.
compilation_config
.
cudagraph_capture_sizes
==
[
1
,
2
,
4
,
8
])
...
...
@@ -202,7 +248,7 @@ def test_compilation_config():
# set to string form of a dict
args
=
parser
.
parse_args
([
"--compilation-config="
"{'
level
'
: 3,
'
cudagraph_capture_sizes
'
: [1, 2, 4, 8]}
"
,
'{"
level
"
: 3,
"
cudagraph_capture_sizes
"
: [1, 2, 4, 8]}
'
,
])
assert
(
args
.
compilation_config
.
level
==
3
and
args
.
compilation_config
.
cudagraph_capture_sizes
==
[
1
,
2
,
4
,
8
])
...
...
vllm/compilation/vllm_inductor_pass.py
View file @
4b2ed792
...
...
@@ -4,7 +4,7 @@ import time
import
torch
from
vllm.config
import
Compilation
Config
,
VllmConfig
from
vllm.config
import
Pass
Config
,
VllmConfig
# yapf: disable
from
vllm.distributed
import
get_tensor_model_parallel_rank
as
get_tp_rank
from
vllm.distributed
import
(
...
...
@@ -56,10 +56,7 @@ class VllmInductorPass(InductorPass):
class
PrinterInductorPass
(
VllmInductorPass
):
def
__init__
(
self
,
name
:
str
,
config
:
CompilationConfig
.
PassConfig
,
always
=
False
):
def
__init__
(
self
,
name
:
str
,
config
:
PassConfig
,
always
=
False
):
super
().
__init__
(
config
)
self
.
name
=
name
self
.
always
=
always
...
...
vllm/config.py
View file @
4b2ed792
...
...
@@ -11,8 +11,8 @@ import textwrap
import
warnings
from
collections
import
Counter
from
contextlib
import
contextmanager
from
dataclasses
import
(
MISSING
,
dataclass
,
field
,
fields
,
is_dataclass
,
replace
)
from
dataclasses
import
(
MISSING
,
Field
,
asdict
,
dataclass
,
field
,
fields
,
is_dataclass
,
replace
)
from
functools
import
cached_property
from
importlib.util
import
find_spec
from
pathlib
import
Path
...
...
@@ -20,7 +20,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol
,
TypeVar
,
Union
,
cast
,
get_args
,
get_origin
)
import
torch
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
transformers
import
PretrainedConfig
from
typing_extensions
import
deprecated
...
...
@@ -57,7 +56,7 @@ if TYPE_CHECKING:
ConfigType
=
type
[
DataclassInstance
]
else
:
QuantizationConfig
=
None
QuantizationConfig
=
Any
ConfigType
=
type
logger
=
init_logger
(
__name__
)
...
...
@@ -169,6 +168,12 @@ def config(cls: ConfigT) -> ConfigT:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
If a `ConfigT` is used as a CLI argument itself, the default value provided
by `get_kwargs` will be the result of parsing a JSON string as the kwargs
(i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
requires custom construction from CLI (i.e. `CompilationConfig`), it can
have a `from_cli` method, which will be called instead.
"""
if
not
is_dataclass
(
cls
):
raise
TypeError
(
"The decorated class must be a dataclass."
)
...
...
@@ -202,7 +207,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
cls_fields
=
{
f
.
name
:
f
for
f
in
fields
(
cls
)}
if
name
not
in
cls_fields
:
raise
ValueError
(
f
"Field '
{
name
}
' not found in
{
cls
.
__name__
}
."
)
named_field
:
Field
=
cls_fields
.
get
(
name
)
named_field
:
Field
=
cls_fields
[
name
]
if
(
default_factory
:
=
named_field
.
default_factory
)
is
not
MISSING
:
return
field
(
default_factory
=
default_factory
)
if
(
default
:
=
named_field
.
default
)
is
not
MISSING
:
...
...
@@ -211,6 +216,10 @@ def get_field(cls: ConfigType, name: str) -> Field:
f
"
{
cls
.
__name__
}
.
{
name
}
must have a default value or default factory."
)
def
is_init_field
(
cls
:
ConfigType
,
name
:
str
)
->
bool
:
return
next
(
f
for
f
in
fields
(
cls
)
if
f
.
name
==
name
).
init
TokenizerMode
=
Literal
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
...
...
@@ -2007,13 +2016,13 @@ class SchedulerConfig:
def
__post_init__
(
self
)
->
None
:
if
self
.
max_model_len
is
None
:
self
.
max_model_len
=
8192
logger
.
warning
(
logger
.
warning
_once
(
"max_model_len was is not set. Defaulting to arbitrary value "
"of %d."
,
self
.
max_model_len
)
if
self
.
max_num_seqs
is
None
:
self
.
max_num_seqs
=
128
logger
.
warning
(
logger
.
warning
_once
(
"max_num_seqs was is not set. Defaulting to arbitrary value "
"of %d."
,
self
.
max_num_seqs
)
...
...
@@ -2840,8 +2849,8 @@ class PromptAdapterConfig:
class
MultiModalConfig
:
"""Controls the behavior of multimodal models."""
limit_per_prompt
:
dict
[
str
,
int
]
=
get_field
(
ModelConfig
,
"limit_mm_per_prompt"
)
limit_per_prompt
:
dict
[
str
,
int
]
=
\
cast
(
dict
[
str
,
int
],
get_field
(
ModelConfig
,
"limit_mm_per_prompt"
)
)
"""
The maximum number of input items allowed per prompt for each modality.
Defaults to 1 (V0) or 999 (V1) for each modality.
...
...
@@ -3415,41 +3424,49 @@ class ObservabilityConfig:
self
.
collect_detailed_traces
[
0
].
split
(
","
))
class
KVTransferConfig
(
BaseModel
):
KVProducer
=
Literal
[
"kv_producer"
,
"kv_both"
]
KVConsumer
=
Literal
[
"kv_consumer"
,
"kv_both"
]
KVRole
=
Literal
[
KVProducer
,
KVConsumer
]
@
config
@
dataclass
class
KVTransferConfig
:
"""Configuration for distributed KV cache transfer."""
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector
:
Optional
[
str
]
=
None
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
# The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'.
kv_buffer_device
:
Optional
[
str
]
=
"cuda"
"""The device used by kv connector to buffer the KV cache.
Currently only support 'cuda'."""
# The buffer size for TorchDistributedConnector. Measured in number of
# bytes. Recommended value: 1e9 (about 1GB).
kv_buffer_size
:
float
=
1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
# are 'kv_producer', 'kv_consumer', and 'kv_both'.
kv_ro
le
:
Optional
[
str
]
=
None
kv_role
:
Optional
[
KVRole
]
=
None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
are '
kv_
p
ro
ducer', 'kv_consumer', and 'kv_both'."""
# The rank of this vLLM instance in the KV cache transfer. Typical value:
# 0 for prefill instance, 1 for decode instance.
# Currently only 1P1D is supported.
kv_rank
:
Optional
[
int
]
=
None
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
0 for prefill instance, 1 for decode instance.
Currently only 1P1D is supported."""
# The number of parallel instances for KV cache transfer. For
# PyNcclConnector, this should be 2.
kv_parallel_size
:
int
=
1
"""The number of parallel instances for KV cache transfer. For
PyNcclConnector, this should be 2."""
# The KV connector ip, used to build distributed connection
kv_ip
:
str
=
"127.0.0.1"
"""The KV connector ip, used to build distributed connection."""
# The KV connector port, used to build distributed connection
kv_port
:
int
=
14579
"""The KV connector port, used to build distributed connection."""
# any
extra
config
that the connector may need
kv_connector_
extra
_
config
:
dict
[
str
,
Any
]
=
{}
kv_connector_
extra
_
config
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
"""any
extra
config
that the connector may need."""
def
compute_hash
(
self
)
->
str
:
"""
...
...
@@ -3470,46 +3487,37 @@ class KVTransferConfig(BaseModel):
usedforsecurity
=
False
).
hexdigest
()
return
hash_str
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
)
->
"KVTransferConfig"
:
"""Parse the CLI value for the kv cache transfer config."""
return
KVTransferConfig
.
model_validate_json
(
cli_value
)
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
if
self
.
kv_role
is
not
None
and
self
.
kv_role
not
in
[
"kv_producer"
,
"kv_consumer"
,
"kv_both"
]:
raise
ValueError
(
f
"Unsupported kv_role:
{
self
.
kv_role
}
. "
f
"Supported roles are `kv_producer`, `kv_consumer`, "
f
"and `kv_both`"
)
def
__post_init__
(
self
)
->
None
:
if
self
.
kv_role
is
not
None
and
self
.
kv_role
not
in
get_args
(
KVRole
):
raise
ValueError
(
f
"Unsupported kv_role:
{
self
.
kv_role
}
. "
f
"Supported roles are
{
get_args
(
KVRole
)
}
"
)
if
self
.
kv_connector
is
not
None
and
self
.
kv_role
is
None
:
raise
ValueError
(
"Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`"
)
f
"is set, supported roles are
{
get_args
(
KVRole
)
}
"
)
@
property
def
is_kv_transfer_instance
(
self
)
->
bool
:
return
self
.
kv_connector
is
not
None
and
\
self
.
kv_role
in
[
"kv_producer"
,
"kv_consumer"
,
"kv_both"
]
self
.
kv_role
in
get_args
(
KVRole
)
@
property
def
is_kv_producer
(
self
)
->
bool
:
return
self
.
kv_connector
is
not
None
and
\
self
.
kv_role
in
[
"kv_producer"
,
"kv_both"
]
self
.
kv_role
in
get_args
(
KVProducer
)
@
property
def
is_kv_consumer
(
self
)
->
bool
:
return
self
.
kv_connector
is
not
None
and
\
self
.
kv_role
in
[
"kv_consumer"
,
"kv_both"
]
self
.
kv_role
in
get_args
(
KVConsumer
)
def
get_from_extra_config
(
self
,
key
,
default
)
->
Any
:
return
self
.
kv_connector_extra_config
.
get
(
key
,
default
)
class
KVEventsConfig
(
BaseModel
):
@
config
@
dataclass
class
KVEventsConfig
:
"""Configuration for KV event publishing."""
enable_kv_cache_events
:
bool
=
False
...
...
@@ -3548,11 +3556,6 @@ class KVEventsConfig(BaseModel):
this topic to receive events.
"""
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
)
->
"KVEventsConfig"
:
"""Parse the CLI value for the event publisher config."""
return
KVEventsConfig
.
model_validate_json
(
cli_value
)
class
CompilationLevel
:
# constants for the levels of the compilation process
...
...
@@ -3562,80 +3565,72 @@ class CompilationLevel:
PIECEWISE
=
3
class
CompilationConfig
(
BaseModel
):
"""
Configuration for compilation.
It has three parts:
@
config
@
dataclass
class
PassConfig
:
"""Configuration for custom Inductor passes.
This is separate from general `CompilationConfig` so that inductor passes
don't all have access to full configuration - that would create a cycle as
the `PassManager` is set as a property of config."""
dump_graph_stages
:
list
[
str
]
=
field
(
default_factory
=
list
)
"""List of stages for which we want to dump the graph. Each pass defines
its own stages (before, after, maybe in-between)."""
dump_graph_dir
:
Path
=
Path
(
"."
)
"""Directory to dump the graphs."""
# TODO(luka) better pass enabling system.
enable_fusion
:
bool
=
True
"""Whether to enable the custom fusion pass."""
enable_noop
:
bool
=
True
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism
:
bool
=
False
"""Whether to enable sequence parallelism."""
def
uuid
(
self
):
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
include
=
{
"enable_fusion"
,
"enable_noop"
,
"enable_sequence_parallelism"
}
dict_
=
{
k
:
v
for
k
,
v
in
asdict
(
self
).
items
()
if
k
in
include
}
return
InductorPass
.
hash_dict
(
dict_
)
def
__post_init__
(
self
)
->
None
:
if
not
self
.
enable_noop
and
self
.
enable_fusion
:
logger
.
warning_once
(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work"
)
@
config
@
dataclass
class
CompilationConfig
:
"""Configuration for compilation. It has three parts:
- Top-level Compilation control:
- level: the level of compilation.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- debug_dump_path: the path to dump the debug information.
- cache_dir: the directory to store the compiled graph, to
accelerate Inductor compilation. By default, it will use
model-related information to generate a cache directory.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the backend function.
We use string to avoid serialization issues when using compilation in a distributed setting.
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor).
- splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
- {attr}`level`
- {attr}`debug_dump_path`
- {attr}`cache_dir`
- {attr}`backend`
- {attr}`custom_ops`
- {attr}`splitting_ops`
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
- list[int]: capture sizes are specified as given.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- full_cuda_graph: whether to use a full cuda graph for the entire forward
pass rather than splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
- {attr}`use_cudagraph`
- {attr}`cudagraph_capture_sizes`
- {attr}`cudagraph_num_of_warmups`
- {attr}`cudagraph_copy_inputs`
- {attr}`full_cuda_graph`
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for compile_sizes,
using configurations in inductor_compile_config.
- compile_sizes: sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes: see PassConfig for more details
- {attr}`use_inductor`
- {attr}`compile_sizes`
- {attr}`inductor_compile_config`
- {attr}`inductor_passes`
- custom inductor passes
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
...
...
@@ -3646,83 +3641,135 @@ class CompilationConfig(BaseModel):
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
"""
# noqa
"""
# Top-level Compilation control
level
:
int
=
0
"""The level of compilation:
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation."""
debug_dump_path
:
str
=
""
"""The path to dump the debug information."""
cache_dir
:
str
=
""
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
backend
:
str
=
""
custom_ops
:
list
[
str
]
=
Field
(
default_factory
=
list
)
splitting_ops
:
list
[
str
]
=
Field
(
default
=
None
)
# type: ignore
"""The backend for compilation. It needs to be a string:
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the
backend function.
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph)."""
custom_ops
:
list
[
str
]
=
field
(
default_factory
=
list
)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
names to enable (prefixed with a '+'), or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor (compile_level >= Inductor)."""
splitting_ops
:
list
[
str
]
=
field
(
default_factory
=
list
)
"""A list of ops to split the full graph into subgraphs, used in piecewise
compilation."""
# Inductor compilation
use_inductor
:
bool
=
True
compile_sizes
:
Optional
[
list
[
Union
[
int
,
str
]]]
=
Field
(
default
=
None
)
inductor_compile_config
:
dict
=
Field
(
default_factory
=
dict
)
inductor_passes
:
dict
[
str
,
str
]
=
Field
(
default_factory
=
dict
)
"""Whether to use inductor compilation:
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for compile_sizes,
using configurations in inductor_compile_config."""
compile_sizes
:
Optional
[
list
[
Union
[
int
,
str
]]]
=
None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
inductor_compile_config
:
dict
=
field
(
default_factory
=
dict
)
"""Additional configurations for inductor.
- None: use default configurations."""
inductor_passes
:
dict
[
str
,
str
]
=
field
(
default_factory
=
dict
)
"""Additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses JSON format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
use_cudagraph
:
bool
=
False
"""Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future."""
cudagraph_num_of_warmups
:
int
=
0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs."""
cudagraph_capture_sizes
:
Optional
[
list
[
int
]]
=
None
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
- list[int]: capture sizes are specified as given."""
cudagraph_copy_inputs
:
bool
=
False
"""Whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False."""
full_cuda_graph
:
bool
=
False
class PassConfig(BaseModel):
    """Settings for the custom Inductor passes.

    Kept apart from the general CompilationConfig so that inductor passes
    don't all have access to the full configuration - that would create a
    cycle, as the PassManager is set as a property of config.

    Fields:
    - dump_graph_stages: list of stages for which we want to dump the graph.
        Each pass defines its own stages (before, after, maybe in-between).
    - dump_graph_dir: directory to dump the graphs. Default is "." (the
        current directory).
    - enable_fusion: whether to enable the custom fusion pass.
    - enable_noop: whether to enable the custom no-op elimination pass.
        TODO(luka) better pass enabling system.
    - enable_sequence_parallelism: whether to enable sequence parallelism.
    """

    dump_graph_stages: list[str] = Field(default_factory=list)
    dump_graph_dir: Path = Field(default=Path("."))
    enable_fusion: bool = True
    enable_noop: bool = True
    enable_sequence_parallelism: bool = False

    def uuid(self):
        """Return a hash that identifies this pass configuration.

        Only the fields that affect compilation are hashed; any new field
        that affects compilation should be added here. The dump_graph_*
        fields are deliberately left out because they don't affect
        compilation.
        """
        hashed_fields = {
            "enable_fusion",
            "enable_noop",
            "enable_sequence_parallelism",
        }
        return InductorPass.hash_dict(self.model_dump(include=hashed_fields))

    def model_post_init(self, __context: Any) -> None:
        """Warn once when fusion is enabled while no-op elimination is off."""
        if self.enable_fusion and not self.enable_noop:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "RMSNorm + quant (fp8) fusion might not work")
pass_config
:
PassConfig
=
Field
(
default_factory
=
PassConfig
)
# not configurable, computed after init
max_capture_size
:
int
=
PrivateAttr
local_cache_dir
:
str
=
PrivateAttr
# local cache dir for each rank
# optimization:
# Intuitively, bs_to_padded_graph_size should be dict[int, int].
# since we know all keys are in a range [0, max_capture_size],
# we can optimize it to list[int] for better lookup performance.
bs_to_padded_graph_size
:
list
[
int
]
=
PrivateAttr
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
pass_config
:
PassConfig
=
field
(
default_factory
=
PassConfig
)
"""Custom inductor passes, see PassConfig for more details"""
max_capture_size
:
int
=
field
(
default
=
None
,
init
=
False
)
# type: ignore
"""not configurable, computed after init"""
local_cache_dir
:
str
=
field
(
default
=
None
,
init
=
False
)
# type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size
:
list
[
int
]
=
field
(
default
=
None
,
# type: ignore
init
=
False
)
"""optimization:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_capture_size],
we can optimize it to list[int] for better lookup performance."""
# keep track of enabled and disabled custom ops
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
traced_files
:
set
[
str
]
=
PrivateAttr
compilation_time
:
float
=
PrivateAttr
# Per-model forward context
# Map from layer name to layer objects that need to be accessed outside
# model code, e.g., Attention, FusedMOE when dp_size>1.
static_forward_context
:
dict
[
str
,
Any
]
=
PrivateAttr
enabled_custom_ops
:
Counter
[
str
]
=
field
(
default_factory
=
Counter
,
init
=
False
)
"""custom ops that are enabled"""
disabled_custom_ops
:
Counter
[
str
]
=
field
(
default_factory
=
Counter
,
init
=
False
)
"""custom ops that are disabled"""
traced_files
:
set
[
str
]
=
field
(
default_factory
=
set
,
init
=
False
)
"""files that are traced for compilation"""
compilation_time
:
float
=
field
(
default
=
0.0
,
init
=
False
)
"""time taken for compilation"""
static_forward_context
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
,
init
=
False
)
"""Per-model forward context
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
def
compute_hash
(
self
)
->
str
:
"""
...
...
@@ -3757,7 +3804,17 @@ class CompilationConfig(BaseModel):
"pass_config"
,
"traced_files"
,
}
return
self
.
model_dump_json
(
exclude
=
exclude
,
exclude_unset
=
True
)
include
=
dict
()
for
k
,
v
in
asdict
(
self
).
items
():
if
k
in
exclude
:
continue
f
=
get_field
(
CompilationConfig
,
k
)
if
(
d
:
=
f
.
default
)
is
not
MISSING
and
d
==
v
:
continue
if
(
df
:
=
f
.
default_factory
)
is
not
MISSING
and
df
()
==
v
:
continue
include
[
k
]
=
v
return
json
.
dumps
(
include
)
__str__
=
__repr__
...
...
@@ -3766,12 +3823,9 @@ class CompilationConfig(BaseModel):
"""Parse the CLI value for the compilation config."""
if
cli_value
in
[
"0"
,
"1"
,
"2"
,
"3"
]:
return
cls
(
level
=
int
(
cli_value
))
# do not use `eval`, it is dangerous and can execute arbitrary code
dict_value
=
ast
.
literal_eval
(
cli_value
)
return
CompilationConfig
.
model_validate
(
dict_value
)
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
return
cls
(
**
json
.
loads
(
cli_value
))
def
__post_init__
(
self
)
->
None
:
count_none
=
self
.
custom_ops
.
count
(
"none"
)
count_all
=
self
.
custom_ops
.
count
(
"all"
)
assert
count_none
+
count_all
<=
1
,
"Can only specify 'none' or 'all'"
...
...
@@ -3789,9 +3843,6 @@ class CompilationConfig(BaseModel):
if
KEY
not
in
self
.
inductor_compile_config
:
self
.
inductor_compile_config
[
KEY
]
=
False
if
self
.
splitting_ops
is
None
:
self
.
splitting_ops
=
[]
for
k
,
v
in
self
.
inductor_passes
.
items
():
if
not
isinstance
(
v
,
str
):
assert
callable
(
v
),
(
...
...
@@ -3808,11 +3859,8 @@ class CompilationConfig(BaseModel):
self
.
inductor_compile_config
[
k
]
=
func
if
isinstance
(
func
,
InductorPass
)
else
CallableInductorPass
(
func
)
self
.
enabled_custom_ops
=
Counter
()
self
.
disabled_custom_ops
=
Counter
()
self
.
traced_files
=
set
()
self
.
static_forward_context
=
{}
self
.
compilation_time
=
0.0
if
isinstance
(
self
.
pass_config
,
dict
):
self
.
pass_config
=
PassConfig
(
**
self
.
pass_config
)
def
init_backend
(
self
,
vllm_config
:
"VllmConfig"
)
->
Union
[
str
,
Callable
]:
if
self
.
level
==
CompilationLevel
.
NO_COMPILATION
:
...
...
@@ -3899,39 +3947,67 @@ class CompilationConfig(BaseModel):
]
@
config
@
dataclass
class
VllmConfig
:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
model_config
:
ModelConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
cache_config
:
CacheConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
parallel_config
:
ParallelConfig
=
field
(
default_factory
=
ParallelConfig
,
init
=
True
)
scheduler_config
:
SchedulerConfig
=
field
(
default_factory
=
SchedulerConfig
,
init
=
True
)
device_config
:
DeviceConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
load_config
:
LoadConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
model_config
:
ModelConfig
=
field
(
default_factory
=
ModelConfig
)
"""Model configuration."""
cache_config
:
CacheConfig
=
field
(
default_factory
=
CacheConfig
)
"""Cache configuration."""
parallel_config
:
ParallelConfig
=
field
(
default_factory
=
ParallelConfig
)
"""Parallel configuration."""
scheduler_config
:
SchedulerConfig
=
field
(
default_factory
=
SchedulerConfig
)
"""Scheduler configuration."""
device_config
:
DeviceConfig
=
field
(
default_factory
=
DeviceConfig
)
"""Device configuration."""
load_config
:
LoadConfig
=
field
(
default_factory
=
LoadConfig
)
"""Load configuration."""
lora_config
:
Optional
[
LoRAConfig
]
=
None
speculative_config
:
SpeculativeConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
"""LoRA configuration."""
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
"""Speculative decoding configuration."""
decoding_config
:
Optional
[
DecodingConfig
]
=
None
"""Decoding configuration."""
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
"""Observability configuration."""
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
"""Prompt adapter configuration."""
quant_config
:
Optional
[
QuantizationConfig
]
=
None
compilation_config
:
CompilationConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
kv_transfer_config
:
KVTransferConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
"""Quantization configuration."""
compilation_config
:
CompilationConfig
=
field
(
default_factory
=
CompilationConfig
)
"""`torch.compile` configuration for the model.
When it is a number (0, 1, 2, 3), it will be interpreted as the
optimization level.
NOTE: level 0 is the default level without any optimization. level 1 and 2
are for internal testing only. level 3 is the recommended level for
production.
Following the convention of traditional compilers, using `-O` without space
is also supported. `-O3` is equivalent to `-O 3`.
You can specify the full compilation config like so:
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
kv_transfer_config
:
Optional
[
KVTransferConfig
]
=
None
"""The configurations for distributed KV cache transfer."""
kv_events_config
:
Optional
[
KVEventsConfig
]
=
None
"""The configurations for event publishing."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
additional_config
:
SupportsHash
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
additional_config
:
Union
[
dict
,
SupportsHash
]
=
field
(
default_factory
=
dict
)
"""Additional config for specified platform. Different platforms may
support different configs. Make sure the configs are valid for the platform
you are using. Contents must be hashable."""
instance_id
:
str
=
""
"""The ID of the vLLM instance."""
def
compute_hash
(
self
)
->
str
:
"""
...
...
@@ -4012,7 +4088,14 @@ class VllmConfig:
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
additional_config
:
vllm_factors
.
append
(
self
.
additional_config
.
compute_hash
())
if
isinstance
(
additional_config
:
=
self
.
additional_config
,
dict
):
additional_config_hash
=
hashlib
.
md5
(
json
.
dumps
(
additional_config
,
sort_keys
=
True
).
encode
(),
usedforsecurity
=
False
,
).
hexdigest
()
else
:
additional_config_hash
=
additional_config
.
compute_hash
()
vllm_factors
.
append
(
additional_config_hash
)
else
:
vllm_factors
.
append
(
"None"
)
factors
.
append
(
vllm_factors
)
...
...
vllm/distributed/kv_events.py
View file @
4b2ed792
...
...
@@ -5,6 +5,7 @@ import threading
import
time
from
abc
import
ABC
,
abstractmethod
from
collections
import
deque
from
dataclasses
import
asdict
from
itertools
import
count
from
queue
import
Queue
from
typing
import
Any
,
Callable
,
Optional
,
Union
...
...
@@ -284,7 +285,7 @@ class EventPublisherFactory:
if
not
config
:
return
NullEventPublisher
()
config_dict
=
config
.
model_dump
(
)
config_dict
=
asdict
(
config
)
kind
=
config_dict
.
pop
(
"publisher"
,
"null"
)
config_dict
.
pop
(
"enable_kv_cache_events"
)
...
...
vllm/engine/arg_utils.py
View file @
4b2ed792
...
...
@@ -7,10 +7,10 @@ import json
import
re
import
threading
import
warnings
from
dataclasses
import
MISSING
,
dataclass
,
fields
from
dataclasses
import
MISSING
,
dataclass
,
fields
,
is_dataclass
from
itertools
import
permutations
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
get_args
,
get_origin
)
from
typing
import
(
Annotated
,
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
get_args
,
get_origin
)
import
torch
from
typing_extensions
import
TypeIs
,
deprecated
...
...
@@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager
from
vllm.test_utils
import
MODEL_WEIGHTS_S3_BUCKET
,
MODELS_ON_S3
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
GiB_bytes
,
is_in_ray_actor
from
vllm.utils
import
(
FlexibleArgumentParser
,
GiB_bytes
,
is_in_doc_build
,
is_in_ray_actor
)
# yapf: enable
...
...
@@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object]
TypeHintT
=
Union
[
type
[
T
],
object
]
def
optional_type
(
return_type
:
Callable
[[
str
],
T
])
->
Callable
[[
str
],
Optional
[
T
]]:
def
parse_type
(
return_type
:
Callable
[[
str
],
T
])
->
Callable
[[
str
],
T
]:
def
_optional_type
(
val
:
str
)
->
Optional
[
T
]:
if
val
==
""
or
val
==
"None"
:
return
None
def
_parse_type
(
val
:
str
)
->
T
:
try
:
if
return_type
is
json
.
loads
and
not
re
.
match
(
"^{.*}$"
,
val
):
return
cast
(
T
,
nullable_kvs
(
val
))
...
...
@@ -62,14 +60,24 @@ def optional_type(
raise
argparse
.
ArgumentTypeError
(
f
"Value
{
val
}
cannot be converted to
{
return_type
}
."
)
from
e
return
_parse_type
def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Build a CLI value parser that also accepts "no value".

    Args:
        return_type: Callable that converts a string into the target type.

    Returns:
        A callable that returns None for the strings "" and "None", and
        otherwise hands the string to `parse_type(return_type)`.
    """

    def _optional_type(val: str) -> Optional[T]:
        # An empty string or the literal "None" both mean "unset".
        if val in ("", "None"):
            return None
        convert = parse_type(return_type)
        return convert(val)

    return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse a CLI value that may be a plain string or a JSON object.

    Args:
        val: Raw string from the command line.

    Returns:
        `val` as a string when it is not wrapped in curly braces; otherwise
        the result of `optional_type(json.loads)` applied to `val`.
    """
    looks_like_json_object = re.match("^{.*}$", val) is not None
    if looks_like_json_object:
        return optional_type(json.loads)(val)
    return str(val)
@
deprecated
(
...
...
@@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs
=
get_attr_docs
(
cls
)
kwargs
=
{}
for
field
in
fields
(
cls
):
# Get the set of possible types for the field
type_hints
:
set
[
TypeHint
]
=
set
()
if
get_origin
(
field
.
type
)
in
{
Union
,
Annotated
}:
type_hints
.
update
(
get_args
(
field
.
type
))
else
:
type_hints
.
add
(
field
.
type
)
# If the field is a dataclass, it can be initialised from a JSON string
generator
=
(
th
for
th
in
type_hints
if
is_dataclass
(
th
))
dataclass_cls
=
next
(
generator
,
None
)
# Get the default value of the field
default
=
field
.
default
if
field
.
default_factory
is
not
MISSING
:
default
=
field
.
default_factory
()
if
field
.
default
is
not
MISSING
:
default
=
field
.
default
elif
field
.
default_factory
is
not
MISSING
:
if
is_dataclass
(
field
.
default_factory
)
and
is_in_doc_build
():
default
=
{}
else
:
default
=
field
.
default_factory
()
# Get the help text for the field
name
=
field
.
name
...
...
@@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Initialise the kwargs dictionary for the field
kwargs
[
name
]
=
{
"default"
:
default
,
"help"
:
help
}
# Get the set of possible types for the field
type_hints
:
set
[
TypeHint
]
=
set
()
if
get_origin
(
field
.
type
)
is
Union
:
type_hints
.
update
(
get_args
(
field
.
type
))
else
:
type_hints
.
add
(
field
.
type
)
# Set other kwargs based on the type hints
json_tip
=
"
\n\n
Should be a valid JSON string."
if
contains_type
(
type_hints
,
bool
):
if
dataclass_cls
is
not
None
:
dataclass_init
=
lambda
x
,
f
=
dataclass_cls
:
f
(
**
json
.
loads
(
x
))
# Special case for configs with a from_cli method
if
hasattr
(
dataclass_cls
,
"from_cli"
):
from_cli
=
dataclass_cls
.
from_cli
dataclass_init
=
lambda
x
,
f
=
from_cli
:
f
(
x
)
kwargs
[
name
][
"type"
]
=
dataclass_init
kwargs
[
name
][
"help"
]
+=
json_tip
elif
contains_type
(
type_hints
,
bool
):
# Creates --no-<name> and --<name> flags
kwargs
[
name
][
"action"
]
=
argparse
.
BooleanOptionalAction
elif
contains_type
(
type_hints
,
Literal
):
...
...
@@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs
[
name
][
"type"
]
=
union_dict_and_str
elif
contains_type
(
type_hints
,
dict
):
# Dict arguments will always be optional
kwargs
[
name
][
"type"
]
=
optional
_type
(
json
.
loads
)
kwargs
[
name
][
"type"
]
=
parse
_type
(
json
.
loads
)
kwargs
[
name
][
"help"
]
+=
json_tip
elif
(
contains_type
(
type_hints
,
str
)
or
any
(
is_not_builtin
(
th
)
for
th
in
type_hints
)):
...
...
@@ -771,63 +795,20 @@ class EngineArgs:
scheduler_group
.
add_argument
(
"--scheduler-cls"
,
**
scheduler_kwargs
[
"scheduler_cls"
])
# Compilation arguments
# compilation_kwargs = get_kwargs(CompilationConfig)
compilation_group
=
parser
.
add_argument_group
(
title
=
"CompilationConfig"
,
description
=
CompilationConfig
.
__doc__
,
)
compilation_group
.
add_argument
(
"--compilation-config"
,
"-O"
,
type
=
CompilationConfig
.
from_cli
,
default
=
None
,
help
=
"torch.compile configuration for the model. "
"When it is a number (0, 1, 2, 3), it will be "
"interpreted as the optimization level.
\n
"
"NOTE: level 0 is the default level without "
"any optimization. level 1 and 2 are for internal "
"testing only. level 3 is the recommended level "
"for production.
\n
"
"To specify the full compilation config, "
"use a JSON string, e.g. ``{
\"
level
\"
: 3, "
"
\"
cudagraph_capture_sizes
\"
: [1, 2, 4, 8]}``
\n
"
"Following the convention of traditional "
"compilers, using ``-O`` without space is also "
"supported. ``-O3`` is equivalent to ``-O 3``."
)
# KVTransfer arguments
# kv_transfer_kwargs = get_kwargs(KVTransferConfig)
kv_transfer_group
=
parser
.
add_argument_group
(
title
=
"KVTransferConfig"
,
description
=
KVTransferConfig
.
__doc__
,
)
kv_transfer_group
.
add_argument
(
"--kv-transfer-config"
,
type
=
KVTransferConfig
.
from_cli
,
default
=
None
,
help
=
"The configurations for distributed KV cache "
"transfer. Should be a JSON string."
)
kv_transfer_group
.
add_argument
(
'--kv-events-config'
,
type
=
KVEventsConfig
.
from_cli
,
default
=
None
,
help
=
'The configurations for event publishing.'
)
# vLLM arguments
#
vllm_kwargs = get_kwargs(VllmConfig)
vllm_kwargs
=
get_kwargs
(
VllmConfig
)
vllm_group
=
parser
.
add_argument_group
(
title
=
"VllmConfig"
,
description
=
VllmConfig
.
__doc__
,
)
vllm_group
.
add_argument
(
"--additional-
config"
,
type
=
json
.
loads
,
default
=
None
,
help
=
"Addi
tion
al
config
for specified platform in JSON format. "
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{
\"
config_key
\"
:
\"
config_value
\"
}'"
)
vllm_group
.
add_argument
(
"--kv-transfer-config"
,
**
vllm_kwargs
[
"kv_transfer_
config"
])
vllm_group
.
add_argument
(
'--kv-events-config'
,
**
vllm_kwargs
[
"kv_events_config"
])
vllm_group
.
add_argument
(
"--compila
tion
-
config
"
,
"-O"
,
**
vllm_kwargs
[
"compilation_config"
])
vllm_group
.
add_argument
(
"--additional-config"
,
**
vllm_kwargs
[
"additional_config"
]
)
# Other arguments
parser
.
add_argument
(
'--use-v2-block-manager'
,
...
...
vllm/entrypoints/llm.py
View file @
4b2ed792
...
...
@@ -13,7 +13,8 @@ from typing_extensions import TypeVar, deprecated
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
BeamSearchSequence
,
get_beam_search_score
)
from
vllm.config
import
CompilationConfig
,
ModelDType
,
TokenizerMode
from
vllm.config
import
(
CompilationConfig
,
ModelDType
,
TokenizerMode
,
is_init_field
)
from
vllm.engine.arg_utils
import
(
EngineArgs
,
HfOverrides
,
PoolerConfig
,
TaskOption
)
from
vllm.engine.llm_engine
import
LLMEngine
...
...
@@ -204,9 +205,13 @@ class LLM:
kwargs
[
"worker_cls"
]
=
cloudpickle
.
dumps
(
worker_cls
)
if
compilation_config
is
not
None
:
if
isinstance
(
compilation_config
,
(
int
,
dict
)):
compilation_config_instance
=
CompilationConfig
.
from_cli
(
str
(
compilation_config
))
if
isinstance
(
compilation_config
,
int
):
compilation_config_instance
=
CompilationConfig
(
level
=
compilation_config
)
elif
isinstance
(
compilation_config
,
dict
):
predicate
=
lambda
x
:
is_init_field
(
CompilationConfig
,
x
[
0
])
compilation_config_instance
=
CompilationConfig
(
**
dict
(
filter
(
predicate
,
compilation_config
.
items
())))
else
:
compilation_config_instance
=
compilation_config
else
:
...
...
vllm/platforms/tpu.py
View file @
4b2ed792
# SPDX-License-Identifier: Apache-2.0
from
typing
import
TYPE_CHECKING
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Tuple
,
Union
,
cast
import
torch
from
tpu_info
import
device
...
...
@@ -13,9 +13,10 @@ from vllm.sampling_params import SamplingParams, SamplingType
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
BlockSize
,
ModelConfig
,
VllmConfig
from
vllm.pooling_params
import
PoolingParams
else
:
BlockSize
=
None
ModelConfig
=
None
VllmConfig
=
None
PoolingParams
=
None
...
...
@@ -94,7 +95,7 @@ class TpuPlatform(Platform):
cache_config
=
vllm_config
.
cache_config
# For v0, the default block size is 16.
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
16
cache_config
.
block_size
=
cast
(
BlockSize
,
16
)
compilation_config
=
vllm_config
.
compilation_config
# TPU only supports DYNAMO_ONCE compilation level
...
...
@@ -118,7 +119,7 @@ class TpuPlatform(Platform):
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
)
cache_config
.
block_size
=
PallasAttentionBackend
.
get_page_size
(
vllm_config
)
vllm_config
)
# type: ignore[assignment]
min_page_size
=
PallasAttentionBackend
.
get_min_page_size
(
vllm_config
)
if
min_page_size
>
cache_config
.
block_size
:
...
...
@@ -128,7 +129,7 @@ class TpuPlatform(Platform):
cache_config
.
block_size
,
min_page_size
,
)
cache_config
.
block_size
=
min_page_size
cache_config
.
block_size
=
min_page_size
# type: ignore[assignment]
parallel_config
=
vllm_config
.
parallel_config
scheduler_config
=
vllm_config
.
scheduler_config
...
...
vllm/utils.py
View file @
4b2ed792
...
...
@@ -1820,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
return
torch
.
ops
.
_C
.
get_cuda_view_from_cpu_tensor
(
cpu_tensor
)
def
is_in_doc_build
()
->
bool
:
try
:
from
sphinx.ext.autodoc.mock
import
_MockModule
return
isinstance
(
zmq
,
_MockModule
)
except
ModuleNotFoundError
:
return
False
def
import_from_path
(
module_name
:
str
,
file_path
:
Union
[
str
,
os
.
PathLike
]):
"""
Import a Python file according to its file path.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment