Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45bd5c8e
Unverified
Commit
45bd5c8e
authored
Mar 23, 2026
by
Wentao Ye
Committed by
GitHub
Mar 23, 2026
Browse files
[Mypy] Fix mypy for `vllm/config` (#37808)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
10a1018c
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
68 additions
and
59 deletions
+68
-59
tools/pre_commit/mypy.py
tools/pre_commit/mypy.py
+0
-1
vllm/config/attention.py
vllm/config/attention.py
+1
-1
vllm/config/compilation.py
vllm/config/compilation.py
+25
-21
vllm/config/device.py
vllm/config/device.py
+2
-2
vllm/config/kernel.py
vllm/config/kernel.py
+1
-1
vllm/config/kv_events.py
vllm/config/kv_events.py
+1
-1
vllm/config/lora.py
vllm/config/lora.py
+2
-2
vllm/config/model.py
vllm/config/model.py
+11
-10
vllm/config/pooler.py
vllm/config/pooler.py
+2
-2
vllm/config/scheduler.py
vllm/config/scheduler.py
+1
-1
vllm/config/speculative.py
vllm/config/speculative.py
+2
-2
vllm/config/utils.py
vllm/config/utils.py
+4
-4
vllm/config/vllm.py
vllm/config/vllm.py
+8
-8
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/v1/cudagraph_dispatcher.py
vllm/v1/cudagraph_dispatcher.py
+7
-2
No files found.
tools/pre_commit/mypy.py
View file @
45bd5c8e
...
...
@@ -40,7 +40,6 @@ EXCLUDE = [
"vllm/v1/attention/ops"
,
# TODO: Remove these entries after fixing mypy errors.
"vllm/benchmarks"
,
"vllm/config"
,
]
...
...
vllm/config/attention.py
View file @
45bd5c8e
...
...
@@ -56,7 +56,7 @@ class AttentionConfig:
"""
from
vllm.config.utils
import
get_hash_factors
,
hash_factors
ignored_factors
:
li
st
[
str
]
=
[]
ignored_factors
:
s
e
t
[
str
]
=
set
()
factors
=
get_hash_factors
(
self
,
ignored_factors
)
return
hash_factors
(
factors
)
...
...
vllm/config/compilation.py
View file @
45bd5c8e
...
...
@@ -116,29 +116,29 @@ class PassConfig:
"""
# New flags
fuse_norm_quant
:
bool
=
Field
(
default
=
None
)
fuse_norm_quant
:
bool
|
None
=
Field
(
default
=
None
)
"""Fuse the custom RMSNorm + quant ops."""
fuse_act_quant
:
bool
=
Field
(
default
=
None
)
fuse_act_quant
:
bool
|
None
=
Field
(
default
=
None
)
"""Fuse the custom SiluMul + quant ops."""
fuse_attn_quant
:
bool
=
Field
(
default
=
None
)
fuse_attn_quant
:
bool
|
None
=
Field
(
default
=
None
)
"""Fuse the custom attention + quant ops."""
eliminate_noops
:
bool
=
Field
(
default
=
True
)
"""Eliminate no-op ops."""
enable_sp
:
bool
=
Field
(
default
=
None
)
enable_sp
:
bool
|
None
=
Field
(
default
=
None
)
"""Enable sequence parallelism. Requires TP>1. Automatically disabled
if the model's hidden_size is too small for SP to be beneficial
(threshold is device-capability dependent)."""
fuse_gemm_comms
:
bool
=
Field
(
default
=
None
)
fuse_gemm_comms
:
bool
|
None
=
Field
(
default
=
None
)
"""Enable async TP."""
fuse_allreduce_rms
:
bool
=
Field
(
default
=
None
)
fuse_allreduce_rms
:
bool
|
None
=
Field
(
default
=
None
)
"""Enable flashinfer allreduce fusion."""
enable_qk_norm_rope_fusion
:
bool
=
False
"""Enable fused Q/K RMSNorm + RoPE pass."""
# ROCm/AITER specific fusions
fuse_act_padding
:
bool
=
Field
(
default
=
None
)
fuse_act_padding
:
bool
|
None
=
Field
(
default
=
None
)
"""Fuse the custom RMSNorm + padding ops."""
fuse_rope_kvcache
:
bool
=
Field
(
default
=
None
)
fuse_rope_kvcache
:
bool
|
None
=
Field
(
default
=
None
)
"""Fuse the QK rope + KV cache ops."""
rope_kvcache_fusion_max_token_num
:
int
=
256
...
...
@@ -198,9 +198,10 @@ class PassConfig:
if
not
current_platform
.
is_cuda
():
return
{}
return
FI_ALLREDUCE_FUSION_MAX_SIZE_MB
.
get
(
current_platform
.
get_device_capability
().
to_int
(),
{}
)
capability
=
current_platform
.
get_device_capability
()
if
capability
is
None
:
return
{}
return
FI_ALLREDUCE_FUSION_MAX_SIZE_MB
.
get
(
capability
.
to_int
(),
{})
def
compute_hash
(
self
)
->
str
:
"""
...
...
@@ -350,7 +351,7 @@ class DynamicShapesConfig:
from
vllm.config.utils
import
get_hash_factors
,
hash_factors
factors
=
get_hash_factors
(
self
,
{}
)
factors
=
get_hash_factors
(
self
,
set
()
)
return
hash_factors
(
factors
)
...
...
@@ -404,7 +405,7 @@ class CompilationConfig:
"""
# Top-level Compilation control
mode
:
CompilationMode
=
Field
(
default
=
None
)
mode
:
CompilationMode
=
Field
(
default
=
None
)
# type: ignore[assignment]
"""The compilation approach used for torch.compile-based compilation of the
model.
...
...
@@ -544,7 +545,7 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
cudagraph_mode
:
CUDAGraphMode
=
Field
(
default
=
None
)
cudagraph_mode
:
CUDAGraphMode
=
Field
(
default
=
None
)
# type: ignore[assignment]
"""
The mode of the cudagraph:
...
...
@@ -606,7 +607,7 @@ class CompilationConfig:
When `enable_lora` is False, this option has no effect.
"""
use_inductor_graph_partition
:
bool
=
Field
(
default
=
None
)
use_inductor_graph_partition
:
bool
=
Field
(
default
=
None
)
# type: ignore[assignment]
"""Use inductor graph partition to split the graph at cudagraph_unsafe ops.
This partition happens at inductor codegen time after all passes and fusions
are finished. It generates a single `call` function which wraps
...
...
@@ -629,7 +630,7 @@ class CompilationConfig:
pass_config
:
PassConfig
=
field
(
default_factory
=
PassConfig
)
"""Custom inductor passes, see PassConfig for more details"""
max_cudagraph_capture_size
:
int
=
field
(
default
=
None
)
max_cudagraph_capture_size
:
int
|
None
=
field
(
default
=
None
)
"""The maximum cudagraph capture size.
If cudagraph_capture_sizes is specified, this will be set to the largest
...
...
@@ -769,7 +770,9 @@ class CompilationConfig:
exclude
[
"pass_config"
]
=
pass_config_exclude
config
=
TypeAdapter
(
CompilationConfig
).
dump_python
(
self
,
exclude
=
exclude
,
exclude_unset
=
True
self
,
exclude
=
exclude
,
# type: ignore[arg-type]
exclude_unset
=
True
,
)
return
str
(
config
)
...
...
@@ -991,7 +994,7 @@ class CompilationConfig:
- initialize compile_sizes
"""
computed_compile_sizes
=
[]
computed_compile_sizes
:
list
[
int
]
=
[]
if
self
.
compile_sizes
is
not
None
:
# de-duplicate the sizes provided by the config
self
.
compile_sizes
=
list
(
set
(
self
.
compile_sizes
))
...
...
@@ -1001,6 +1004,7 @@ class CompilationConfig:
"Unrecognized size type in compile_sizes, "
f
"expect 'cudagraph_capture_sizes', got
{
x
}
"
)
assert
self
.
cudagraph_capture_sizes
is
not
None
computed_compile_sizes
.
extend
(
self
.
cudagraph_capture_sizes
)
else
:
assert
isinstance
(
x
,
int
)
...
...
@@ -1008,6 +1012,7 @@ class CompilationConfig:
self
.
compile_sizes
=
computed_compile_sizes
# type: ignore
# make sure the sizes are in ascending order
assert
self
.
cudagraph_capture_sizes
is
not
None
self
.
cudagraph_capture_sizes
.
sort
()
if
self
.
cudagraph_capture_sizes
:
assert
self
.
cudagraph_capture_sizes
[
-
1
]
==
self
.
max_cudagraph_capture_size
...
...
@@ -1099,6 +1104,7 @@ class CompilationConfig:
def
set_splitting_ops_for_attn_fusion
(
self
):
assert
self
.
pass_config
.
fuse_attn_quant
assert
self
.
cudagraph_mode
is
not
None
if
self
.
splitting_ops
is
None
:
self
.
splitting_ops
=
[]
if
self
.
cudagraph_mode
.
has_piecewise_cudagraphs
():
...
...
@@ -1290,6 +1296,4 @@ class CompilationConfig:
if
self
.
compile_ranges_endpoints
is
None
:
return
[]
endpoints
=
sorted
(
set
(
self
.
compile_ranges_endpoints
))
return
[
Range
(
start
=
s
+
1
,
end
=
e
)
for
s
,
e
in
zip
([
0
]
+
endpoints
[:
-
1
],
endpoints
)
]
return
[
Range
(
s
+
1
,
e
)
for
s
,
e
in
zip
([
0
]
+
endpoints
[:
-
1
],
endpoints
)]
vllm/config/device.py
View file @
45bd5c8e
...
...
@@ -13,8 +13,8 @@ from vllm.utils.hashing import safe_hash
Device
=
Literal
[
"auto"
,
"cuda"
,
"cpu"
,
"tpu"
,
"xpu"
]
@
config
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
class
DeviceConfig
:
@
config
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
# type: ignore[arg-type,misc]
class
DeviceConfig
:
# type: ignore[misc]
"""Configuration for the device to use for vLLM execution."""
device
:
SkipValidation
[
Device
|
torch
.
device
|
None
]
=
"auto"
...
...
vllm/config/kernel.py
View file @
45bd5c8e
...
...
@@ -26,7 +26,7 @@ MoEBackend = Literal[
class
KernelConfig
:
"""Configuration for kernel selection and warmup behavior."""
enable_flashinfer_autotune
:
bool
=
Field
(
default
=
None
)
enable_flashinfer_autotune
:
bool
|
None
=
Field
(
default
=
None
)
"""If True, run FlashInfer autotuning during kernel warmup."""
moe_backend
:
MoEBackend
=
"auto"
...
...
vllm/config/kv_events.py
View file @
45bd5c8e
...
...
@@ -18,7 +18,7 @@ class KVEventsConfig:
Events can be published externally by zmq using the event publisher config.
"""
publisher
:
Literal
[
"null"
,
"zmq"
]
=
Field
(
default
=
None
)
publisher
:
Literal
[
"null"
,
"zmq"
]
|
None
=
Field
(
default
=
None
)
"""The publisher to use for publishing kv events. Can be "null", "zmq".
"""
...
...
vllm/config/lora.py
View file @
45bd5c8e
...
...
@@ -25,8 +25,8 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
LoRAExtraVocabSize
=
Literal
[
256
,
512
]
@
config
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
class
LoRAConfig
:
@
config
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
# type: ignore[arg-type,misc]
class
LoRAConfig
:
# type: ignore[misc]
"""Configuration for LoRA."""
max_lora_rank
:
MaxLoRARanks
=
16
...
...
vllm/config/model.py
View file @
45bd5c8e
...
...
@@ -93,7 +93,7 @@ LayerBlockType = Literal["attention", "linear_attention", "mamba"]
_RUNNER_CONVERTS
:
dict
[
RunnerType
,
list
[
ConvertType
]]
=
{
"generate"
:
[],
"pooling"
:
[
"embed"
,
"classify"
,
"reward"
],
"pooling"
:
[
"embed"
,
"classify"
],
"draft"
:
[],
}
...
...
@@ -102,8 +102,8 @@ AttnTypeStr = Literal[
]
@
config
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
class
ModelConfig
:
@
config
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
# type: ignore[arg-type,misc]
class
ModelConfig
:
# type: ignore[misc]
"""Configuration for the model."""
model
:
str
=
"Qwen/Qwen3-0.6B"
...
...
@@ -121,7 +121,7 @@ class ModelConfig:
"""Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks."""
tokenizer
:
str
=
Field
(
default
=
None
)
tokenizer
:
str
=
Field
(
default
=
None
)
# type: ignore[assignment]
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode
:
TokenizerMode
|
str
=
"auto"
...
...
@@ -177,7 +177,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
max_model_len
:
int
=
Field
(
default
=
None
,
ge
=-
1
)
max_model_len
:
int
=
Field
(
default
=
None
,
ge
=-
1
)
# type: ignore[assignment]
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.
...
...
@@ -454,7 +454,7 @@ class ModelConfig:
self
.
hf_config_path
=
maybe_model_redirect
(
self
.
hf_config_path
)
if
callable
(
self
.
hf_overrides
):
hf_overrides_kw
=
{}
hf_overrides_kw
:
dict
[
str
,
Any
]
=
{}
hf_overrides_fn
=
self
.
hf_overrides
dict_overrides
:
dict
[
str
,
Any
]
=
{}
else
:
...
...
@@ -582,7 +582,7 @@ class ModelConfig:
self
.
dtype
,
is_pooling_model
=
self
.
runner_type
==
"pooling"
,
revision
=
self
.
revision
,
config_format
=
self
.
config_format
,
config_format
=
self
.
config_format
,
# type: ignore[arg-type]
)
self
.
original_max_model_len
=
self
.
max_model_len
...
...
@@ -626,7 +626,7 @@ class ModelConfig:
k
:
v
for
k
,
v
in
mm_config_kwargs
.
items
()
if
v
is
not
None
}
self
.
multimodal_config
=
MultiModalConfig
(
**
mm_config_kwargs
)
self
.
multimodal_config
=
MultiModalConfig
(
**
mm_config_kwargs
)
# type: ignore[arg-type]
# Multimodal GGUF models must use original repo for mm processing
if
is_gguf
(
self
.
tokenizer
)
and
self
.
is_multimodal_model
:
...
...
@@ -732,7 +732,7 @@ class ModelConfig:
@
property
def
architectures
(
self
)
->
list
[
str
]:
return
self
.
model_arch_config
.
architectures
return
self
.
model_arch_config
.
architectures
# type: ignore[return-value]
@
property
def
architecture
(
self
)
->
str
:
...
...
@@ -1004,7 +1004,7 @@ class ModelConfig:
is_bitsandbytes
=
self
.
quantization
==
"bitsandbytes"
has_quantization_config
=
self
.
model_arch_config
.
quantization_config
is
not
None
is_8bit
=
(
self
.
model_arch_config
.
quantization_config
.
get
(
"load_in_8bit"
,
False
)
self
.
model_arch_config
.
quantization_config
.
get
(
"load_in_8bit"
,
False
)
# type: ignore[union-attr]
if
has_quantization_config
else
False
)
...
...
@@ -1292,6 +1292,7 @@ class ModelConfig:
"attn_type_list, or a layer_types in the hf_config, "
f
"cannot determine the num of
{
block_type
}
layers"
)
raise
AssertionError
(
f
"Unsupported block type:
{
block_type
}
"
)
def
get_mamba_chunk_size
(
self
)
->
int
|
None
:
"""
...
...
vllm/config/pooler.py
View file @
45bd5c8e
...
...
@@ -108,14 +108,14 @@ class PoolerConfig:
pooling_type
,
pooling_type
,
)
self
.
seq_pooling_type
=
pooling_type
self
.
seq_pooling_type
=
pooling_type
# type: ignore[assignment]
elif
pooling_type
in
TOK_POOLING_TYPES
:
logger
.
debug
(
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`."
,
pooling_type
,
pooling_type
,
)
self
.
tok_pooling_type
=
pooling_type
self
.
tok_pooling_type
=
pooling_type
# type: ignore[assignment]
else
:
raise
NotImplementedError
(
pooling_type
)
...
...
vllm/config/scheduler.py
View file @
45bd5c8e
...
...
@@ -173,7 +173,7 @@ class SchedulerConfig:
logger
.
warning_once
(
"Using custom scheduler class %s. This scheduler interface is "
"not public and compatibility may not be maintained."
,
self
.
scheduler_cls
,
self
.
scheduler_cls
,
# type: ignore[arg-type]
)
if
not
isinstance
(
self
.
scheduler_cls
,
str
):
return
cast
(
type
[
"SchedulerInterface"
],
self
.
scheduler_cls
)
...
...
vllm/config/speculative.py
View file @
45bd5c8e
...
...
@@ -67,7 +67,7 @@ class SpeculativeConfig:
enforce_eager
:
bool
|
None
=
None
"""Override the default enforce_eager from model_config"""
# General speculative decoding control
num_speculative_tokens
:
int
=
Field
(
default
=
None
,
gt
=
0
)
num_speculative_tokens
:
int
=
Field
(
default
=
None
,
gt
=
0
)
# type: ignore[assignment]
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required."""
model
:
str
|
None
=
None
...
...
@@ -89,7 +89,7 @@ class SpeculativeConfig:
warn users when they mistakenly provide the wrong argument."""
# Draft model configuration
quantization
:
me_quant
.
QuantizationMethods
|
None
=
None
quantization
:
me_quant
.
QuantizationMethods
|
str
|
None
=
None
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
...
...
vllm/config/utils.py
View file @
45bd5c8e
...
...
@@ -11,13 +11,13 @@ import os
import
pathlib
import
textwrap
from
collections.abc
import
Callable
,
Mapping
,
Sequence
,
Set
from
dataclasses
import
MISSING
,
field
,
fields
,
is_dataclass
from
dataclasses
import
MISSING
,
dataclass
,
field
,
fields
,
is_dataclass
from
itertools
import
pairwise
from
typing
import
TYPE_CHECKING
,
Any
,
Protocol
,
TypeVar
,
cast
import
torch
from
pydantic
import
ConfigDict
from
pydantic.dataclasses
import
dataclass
from
pydantic.dataclasses
import
dataclass
as
pydantic_dataclass
from
pydantic.fields
import
Field
as
PydanticField
from
pydantic.fields
import
FieldInfo
from
typing_extensions
import
dataclass_transform
,
runtime_checkable
...
...
@@ -58,8 +58,8 @@ def config(
if
config
is
not
None
:
merged_config
.
update
(
config
)
def
decorator
(
cls
)
:
return
dataclass
(
cls
,
config
=
merged_config
,
**
kwargs
)
def
decorator
(
cls
:
type
[
ConfigT
])
->
type
[
ConfigT
]
:
return
pydantic_
dataclass
(
cls
,
config
=
merged_config
,
**
kwargs
)
# type: ignore[return-value]
# Called with arguments: @config(config=...)
if
cls
is
None
:
...
...
vllm/config/vllm.py
View file @
45bd5c8e
...
...
@@ -243,15 +243,15 @@ OPTIMIZATION_LEVEL_TO_CONFIG = {
}
@
config
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
class
VllmConfig
:
@
config
(
config
=
ConfigDict
(
arbitrary_types_allowed
=
True
))
# type: ignore[arg-type,misc]
class
VllmConfig
:
# type: ignore[misc]
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
# TODO: use default_factory once default constructing ModelConfig doesn't
# try to download a model
model_config
:
ModelConfig
=
Field
(
default
=
None
)
model_config
:
ModelConfig
=
Field
(
default
=
None
)
# type: ignore[assignment]
"""Model configuration."""
cache_config
:
CacheConfig
=
Field
(
default_factory
=
CacheConfig
)
"""Cache configuration."""
...
...
@@ -883,7 +883,7 @@ class VllmConfig:
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
hidden_size
=
self
.
model_config
.
get_hidden_size
()
element_size
=
self
.
model_config
.
dtype
.
itemsize
element_size
=
self
.
model_config
.
dtype
.
itemsize
# type: ignore[union-attr]
pass_config
.
sp_min_token_num
=
get_sequence_parallelism_threshold
(
hidden_size
,
tp_size
,
element_size
)
...
...
@@ -1061,7 +1061,7 @@ class VllmConfig:
is_fullgraph
=
(
self
.
compilation_config
.
use_inductor_graph_partition
or
len
(
self
.
compilation_config
.
splitting_ops
)
==
0
or
len
(
self
.
compilation_config
.
splitting_ops
or
[]
)
==
0
)
if
self
.
parallel_config
.
pipeline_parallel_size
>
1
or
not
is_fullgraph
:
if
"-rms_norm"
not
in
self
.
compilation_config
.
custom_ops
:
...
...
@@ -1216,7 +1216,7 @@ class VllmConfig:
)
self
.
compilation_config
.
debug_dump_path
=
env_path
def
has_blocked_weights
():
def
has_blocked_weights
():
# type: ignore[no-redef]
if
self
.
quant_config
is
not
None
:
if
hasattr
(
self
.
quant_config
,
"weight_block_size"
):
return
self
.
quant_config
.
weight_block_size
is
not
None
...
...
@@ -1474,7 +1474,7 @@ class VllmConfig:
if
max_size
is
not
None
:
max_token_num
=
max_size
//
(
self
.
model_config
.
get_hidden_size
()
*
self
.
model_config
.
dtype
.
itemsize
*
self
.
model_config
.
dtype
.
itemsize
# type: ignore[union-attr]
)
if
compile_range_end
is
not
None
and
max_token_num
<
compile_range_end
:
computed_compile_ranges_endpoints
.
append
(
max_token_num
)
...
...
@@ -1497,7 +1497,7 @@ class VllmConfig:
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
hidden_size
=
self
.
model_config
.
get_hidden_size
()
element_size
=
self
.
model_config
.
dtype
.
itemsize
element_size
=
self
.
model_config
.
dtype
.
itemsize
# type: ignore[union-attr]
pass_config
.
sp_min_token_num
=
get_sequence_parallelism_threshold
(
hidden_size
,
tp_size
,
element_size
)
...
...
vllm/engine/arg_utils.py
View file @
45bd5c8e
...
...
@@ -1924,7 +1924,7 @@ class EngineArgs:
)
offload_config
=
OffloadConfig
(
offload_backend
=
self
.
offload_backend
,
offload_backend
=
self
.
offload_backend
,
# type: ignore[arg-type]
uva
=
UVAOffloadConfig
(
cpu_offload_gb
=
self
.
cpu_offload_gb
,
cpu_offload_params
=
self
.
cpu_offload_params
,
...
...
vllm/v1/cudagraph_dispatcher.py
View file @
45bd5c8e
...
...
@@ -72,6 +72,9 @@ class CudagraphDispatcher:
"""Pre-compute the mapping from batch size to padded graph size."""
max_size
=
self
.
compilation_config
.
max_cudagraph_capture_size
capture_sizes
=
self
.
compilation_config
.
cudagraph_capture_sizes
assert
max_size
is
not
None
,
(
"Maximum cudagraph capture size must be set when cudagraphs are enabled."
)
assert
capture_sizes
is
not
None
,
(
"Cudagraph capture sizes must be set when cudagraphs are enabled."
)
...
...
@@ -94,7 +97,7 @@ class CudagraphDispatcher:
):
for
size
in
self
.
compilation_config
.
compile_sizes
:
size
=
int
(
size
)
if
size
<=
self
.
compilation_config
.
max_cudagraph_capture
_size
:
if
size
<=
max
_size
:
padded
=
self
.
_bs_to_padded_graph_size
[
size
]
if
padded
!=
size
:
raise
ValueError
(
...
...
@@ -265,11 +268,13 @@ class CudagraphDispatcher:
f
"No allowed cudagraph modes: valid_modes=
{
valid_modes
}
, "
f
"invalid_modes=
{
invalid_modes
}
"
)
max_size
=
self
.
compilation_config
.
max_cudagraph_capture_size
if
(
not
self
.
keys_initialized
or
self
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
or
num_tokens
>
self
.
compilation_config
.
max_cudagraph_capture_size
or
max_size
is
None
or
num_tokens
>
max_size
or
allowed_modes
<=
{
CUDAGraphMode
.
NONE
}
):
return
CUDAGraphMode
.
NONE
,
BatchDescriptor
(
num_tokens
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment