Unverified Commit 45bd5c8e authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Mypy] Fix mypy for `vllm/config` (#37808)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 10a1018c
...@@ -40,7 +40,6 @@ EXCLUDE = [ ...@@ -40,7 +40,6 @@ EXCLUDE = [
"vllm/v1/attention/ops", "vllm/v1/attention/ops",
# TODO: Remove these entries after fixing mypy errors. # TODO: Remove these entries after fixing mypy errors.
"vllm/benchmarks", "vllm/benchmarks",
"vllm/config",
] ]
......
...@@ -56,7 +56,7 @@ class AttentionConfig: ...@@ -56,7 +56,7 @@ class AttentionConfig:
""" """
from vllm.config.utils import get_hash_factors, hash_factors from vllm.config.utils import get_hash_factors, hash_factors
ignored_factors: list[str] = [] ignored_factors: set[str] = set()
factors = get_hash_factors(self, ignored_factors) factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors) return hash_factors(factors)
......
...@@ -116,29 +116,29 @@ class PassConfig: ...@@ -116,29 +116,29 @@ class PassConfig:
""" """
# New flags # New flags
fuse_norm_quant: bool = Field(default=None) fuse_norm_quant: bool | None = Field(default=None)
"""Fuse the custom RMSNorm + quant ops.""" """Fuse the custom RMSNorm + quant ops."""
fuse_act_quant: bool = Field(default=None) fuse_act_quant: bool | None = Field(default=None)
"""Fuse the custom SiluMul + quant ops.""" """Fuse the custom SiluMul + quant ops."""
fuse_attn_quant: bool = Field(default=None) fuse_attn_quant: bool | None = Field(default=None)
"""Fuse the custom attention + quant ops.""" """Fuse the custom attention + quant ops."""
eliminate_noops: bool = Field(default=True) eliminate_noops: bool = Field(default=True)
"""Eliminate no-op ops.""" """Eliminate no-op ops."""
enable_sp: bool = Field(default=None) enable_sp: bool | None = Field(default=None)
"""Enable sequence parallelism. Requires TP>1. Automatically disabled """Enable sequence parallelism. Requires TP>1. Automatically disabled
if the model's hidden_size is too small for SP to be beneficial if the model's hidden_size is too small for SP to be beneficial
(threshold is device-capability dependent).""" (threshold is device-capability dependent)."""
fuse_gemm_comms: bool = Field(default=None) fuse_gemm_comms: bool | None = Field(default=None)
"""Enable async TP.""" """Enable async TP."""
fuse_allreduce_rms: bool = Field(default=None) fuse_allreduce_rms: bool | None = Field(default=None)
"""Enable flashinfer allreduce fusion.""" """Enable flashinfer allreduce fusion."""
enable_qk_norm_rope_fusion: bool = False enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass.""" """Enable fused Q/K RMSNorm + RoPE pass."""
# ROCm/AITER specific fusions # ROCm/AITER specific fusions
fuse_act_padding: bool = Field(default=None) fuse_act_padding: bool | None = Field(default=None)
"""Fuse the custom RMSNorm + padding ops.""" """Fuse the custom RMSNorm + padding ops."""
fuse_rope_kvcache: bool = Field(default=None) fuse_rope_kvcache: bool | None = Field(default=None)
"""Fuse the QK rope + KV cache ops.""" """Fuse the QK rope + KV cache ops."""
rope_kvcache_fusion_max_token_num: int = 256 rope_kvcache_fusion_max_token_num: int = 256
...@@ -198,9 +198,10 @@ class PassConfig: ...@@ -198,9 +198,10 @@ class PassConfig:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
return {} return {}
return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get( capability = current_platform.get_device_capability()
current_platform.get_device_capability().to_int(), {} if capability is None:
) return {}
return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get(capability.to_int(), {})
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
...@@ -350,7 +351,7 @@ class DynamicShapesConfig: ...@@ -350,7 +351,7 @@ class DynamicShapesConfig:
from vllm.config.utils import get_hash_factors, hash_factors from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, {}) factors = get_hash_factors(self, set())
return hash_factors(factors) return hash_factors(factors)
...@@ -404,7 +405,7 @@ class CompilationConfig: ...@@ -404,7 +405,7 @@ class CompilationConfig:
""" """
# Top-level Compilation control # Top-level Compilation control
mode: CompilationMode = Field(default=None) mode: CompilationMode = Field(default=None) # type: ignore[assignment]
"""The compilation approach used for torch.compile-based compilation of the """The compilation approach used for torch.compile-based compilation of the
model. model.
...@@ -544,7 +545,7 @@ class CompilationConfig: ...@@ -544,7 +545,7 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation # CudaGraph compilation
cudagraph_mode: CUDAGraphMode = Field(default=None) cudagraph_mode: CUDAGraphMode = Field(default=None) # type: ignore[assignment]
""" """
The mode of the cudagraph: The mode of the cudagraph:
...@@ -606,7 +607,7 @@ class CompilationConfig: ...@@ -606,7 +607,7 @@ class CompilationConfig:
When `enable_lora` is False, this option has no effect. When `enable_lora` is False, this option has no effect.
""" """
use_inductor_graph_partition: bool = Field(default=None) use_inductor_graph_partition: bool = Field(default=None) # type: ignore[assignment]
"""Use inductor graph partition to split the graph at cudagraph_unsafe ops. """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
This partition happens at inductor codegen time after all passes and fusions This partition happens at inductor codegen time after all passes and fusions
are finished. It generates a single `call` function which wraps are finished. It generates a single `call` function which wraps
...@@ -629,7 +630,7 @@ class CompilationConfig: ...@@ -629,7 +630,7 @@ class CompilationConfig:
pass_config: PassConfig = field(default_factory=PassConfig) pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details""" """Custom inductor passes, see PassConfig for more details"""
max_cudagraph_capture_size: int = field(default=None) max_cudagraph_capture_size: int | None = field(default=None)
"""The maximum cudagraph capture size. """The maximum cudagraph capture size.
If cudagraph_capture_sizes is specified, this will be set to the largest If cudagraph_capture_sizes is specified, this will be set to the largest
...@@ -769,7 +770,9 @@ class CompilationConfig: ...@@ -769,7 +770,9 @@ class CompilationConfig:
exclude["pass_config"] = pass_config_exclude exclude["pass_config"] = pass_config_exclude
config = TypeAdapter(CompilationConfig).dump_python( config = TypeAdapter(CompilationConfig).dump_python(
self, exclude=exclude, exclude_unset=True self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True,
) )
return str(config) return str(config)
...@@ -991,7 +994,7 @@ class CompilationConfig: ...@@ -991,7 +994,7 @@ class CompilationConfig:
- initialize compile_sizes - initialize compile_sizes
""" """
computed_compile_sizes = [] computed_compile_sizes: list[int] = []
if self.compile_sizes is not None: if self.compile_sizes is not None:
# de-duplicate the sizes provided by the config # de-duplicate the sizes provided by the config
self.compile_sizes = list(set(self.compile_sizes)) self.compile_sizes = list(set(self.compile_sizes))
...@@ -1001,6 +1004,7 @@ class CompilationConfig: ...@@ -1001,6 +1004,7 @@ class CompilationConfig:
"Unrecognized size type in compile_sizes, " "Unrecognized size type in compile_sizes, "
f"expect 'cudagraph_capture_sizes', got {x}" f"expect 'cudagraph_capture_sizes', got {x}"
) )
assert self.cudagraph_capture_sizes is not None
computed_compile_sizes.extend(self.cudagraph_capture_sizes) computed_compile_sizes.extend(self.cudagraph_capture_sizes)
else: else:
assert isinstance(x, int) assert isinstance(x, int)
...@@ -1008,6 +1012,7 @@ class CompilationConfig: ...@@ -1008,6 +1012,7 @@ class CompilationConfig:
self.compile_sizes = computed_compile_sizes # type: ignore self.compile_sizes = computed_compile_sizes # type: ignore
# make sure the sizes are in ascending order # make sure the sizes are in ascending order
assert self.cudagraph_capture_sizes is not None
self.cudagraph_capture_sizes.sort() self.cudagraph_capture_sizes.sort()
if self.cudagraph_capture_sizes: if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
...@@ -1099,6 +1104,7 @@ class CompilationConfig: ...@@ -1099,6 +1104,7 @@ class CompilationConfig:
def set_splitting_ops_for_attn_fusion(self): def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.fuse_attn_quant assert self.pass_config.fuse_attn_quant
assert self.cudagraph_mode is not None
if self.splitting_ops is None: if self.splitting_ops is None:
self.splitting_ops = [] self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs(): if self.cudagraph_mode.has_piecewise_cudagraphs():
...@@ -1290,6 +1296,4 @@ class CompilationConfig: ...@@ -1290,6 +1296,4 @@ class CompilationConfig:
if self.compile_ranges_endpoints is None: if self.compile_ranges_endpoints is None:
return [] return []
endpoints = sorted(set(self.compile_ranges_endpoints)) endpoints = sorted(set(self.compile_ranges_endpoints))
return [ return [Range(s + 1, e) for s, e in zip([0] + endpoints[:-1], endpoints)]
Range(start=s + 1, end=e) for s, e in zip([0] + endpoints[:-1], endpoints)
]
...@@ -13,8 +13,8 @@ from vllm.utils.hashing import safe_hash ...@@ -13,8 +13,8 @@ from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config(config=ConfigDict(arbitrary_types_allowed=True)) @config(config=ConfigDict(arbitrary_types_allowed=True)) # type: ignore[arg-type,misc]
class DeviceConfig: class DeviceConfig: # type: ignore[misc]
"""Configuration for the device to use for vLLM execution.""" """Configuration for the device to use for vLLM execution."""
device: SkipValidation[Device | torch.device | None] = "auto" device: SkipValidation[Device | torch.device | None] = "auto"
......
...@@ -26,7 +26,7 @@ MoEBackend = Literal[ ...@@ -26,7 +26,7 @@ MoEBackend = Literal[
class KernelConfig: class KernelConfig:
"""Configuration for kernel selection and warmup behavior.""" """Configuration for kernel selection and warmup behavior."""
enable_flashinfer_autotune: bool = Field(default=None) enable_flashinfer_autotune: bool | None = Field(default=None)
"""If True, run FlashInfer autotuning during kernel warmup.""" """If True, run FlashInfer autotuning during kernel warmup."""
moe_backend: MoEBackend = "auto" moe_backend: MoEBackend = "auto"
......
...@@ -18,7 +18,7 @@ class KVEventsConfig: ...@@ -18,7 +18,7 @@ class KVEventsConfig:
Events can be published externally by zmq using the event publisher config. Events can be published externally by zmq using the event publisher config.
""" """
publisher: Literal["null", "zmq"] = Field(default=None) publisher: Literal["null", "zmq"] | None = Field(default=None)
"""The publisher to use for publishing kv events. Can be "null", "zmq". """The publisher to use for publishing kv events. Can be "null", "zmq".
""" """
......
...@@ -25,8 +25,8 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512] ...@@ -25,8 +25,8 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
LoRAExtraVocabSize = Literal[256, 512] LoRAExtraVocabSize = Literal[256, 512]
@config(config=ConfigDict(arbitrary_types_allowed=True)) @config(config=ConfigDict(arbitrary_types_allowed=True)) # type: ignore[arg-type,misc]
class LoRAConfig: class LoRAConfig: # type: ignore[misc]
"""Configuration for LoRA.""" """Configuration for LoRA."""
max_lora_rank: MaxLoRARanks = 16 max_lora_rank: MaxLoRARanks = 16
......
...@@ -93,7 +93,7 @@ LayerBlockType = Literal["attention", "linear_attention", "mamba"] ...@@ -93,7 +93,7 @@ LayerBlockType = Literal["attention", "linear_attention", "mamba"]
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = { _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
"generate": [], "generate": [],
"pooling": ["embed", "classify", "reward"], "pooling": ["embed", "classify"],
"draft": [], "draft": [],
} }
...@@ -102,8 +102,8 @@ AttnTypeStr = Literal[ ...@@ -102,8 +102,8 @@ AttnTypeStr = Literal[
] ]
@config(config=ConfigDict(arbitrary_types_allowed=True)) @config(config=ConfigDict(arbitrary_types_allowed=True)) # type: ignore[arg-type,misc]
class ModelConfig: class ModelConfig: # type: ignore[misc]
"""Configuration for the model.""" """Configuration for the model."""
model: str = "Qwen/Qwen3-0.6B" model: str = "Qwen/Qwen3-0.6B"
...@@ -121,7 +121,7 @@ class ModelConfig: ...@@ -121,7 +121,7 @@ class ModelConfig:
"""Convert the model using adapters defined in """Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to [vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks.""" adapt a text generation model to be used for pooling tasks."""
tokenizer: str = Field(default=None) tokenizer: str = Field(default=None) # type: ignore[assignment]
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model """Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used.""" name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto" tokenizer_mode: TokenizerMode | str = "auto"
...@@ -177,7 +177,7 @@ class ModelConfig: ...@@ -177,7 +177,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub. """The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version.""" use the default version."""
max_model_len: int = Field(default=None, ge=-1) max_model_len: int = Field(default=None, ge=-1) # type: ignore[assignment]
"""Model context length (prompt and output). If unspecified, will be """Model context length (prompt and output). If unspecified, will be
automatically derived from the model config. automatically derived from the model config.
...@@ -454,7 +454,7 @@ class ModelConfig: ...@@ -454,7 +454,7 @@ class ModelConfig:
self.hf_config_path = maybe_model_redirect(self.hf_config_path) self.hf_config_path = maybe_model_redirect(self.hf_config_path)
if callable(self.hf_overrides): if callable(self.hf_overrides):
hf_overrides_kw = {} hf_overrides_kw: dict[str, Any] = {}
hf_overrides_fn = self.hf_overrides hf_overrides_fn = self.hf_overrides
dict_overrides: dict[str, Any] = {} dict_overrides: dict[str, Any] = {}
else: else:
...@@ -582,7 +582,7 @@ class ModelConfig: ...@@ -582,7 +582,7 @@ class ModelConfig:
self.dtype, self.dtype,
is_pooling_model=self.runner_type == "pooling", is_pooling_model=self.runner_type == "pooling",
revision=self.revision, revision=self.revision,
config_format=self.config_format, config_format=self.config_format, # type: ignore[arg-type]
) )
self.original_max_model_len = self.max_model_len self.original_max_model_len = self.max_model_len
...@@ -626,7 +626,7 @@ class ModelConfig: ...@@ -626,7 +626,7 @@ class ModelConfig:
k: v for k, v in mm_config_kwargs.items() if v is not None k: v for k, v in mm_config_kwargs.items() if v is not None
} }
self.multimodal_config = MultiModalConfig(**mm_config_kwargs) self.multimodal_config = MultiModalConfig(**mm_config_kwargs) # type: ignore[arg-type]
# Multimodal GGUF models must use original repo for mm processing # Multimodal GGUF models must use original repo for mm processing
if is_gguf(self.tokenizer) and self.is_multimodal_model: if is_gguf(self.tokenizer) and self.is_multimodal_model:
...@@ -732,7 +732,7 @@ class ModelConfig: ...@@ -732,7 +732,7 @@ class ModelConfig:
@property @property
def architectures(self) -> list[str]: def architectures(self) -> list[str]:
return self.model_arch_config.architectures return self.model_arch_config.architectures # type: ignore[return-value]
@property @property
def architecture(self) -> str: def architecture(self) -> str:
...@@ -1004,7 +1004,7 @@ class ModelConfig: ...@@ -1004,7 +1004,7 @@ class ModelConfig:
is_bitsandbytes = self.quantization == "bitsandbytes" is_bitsandbytes = self.quantization == "bitsandbytes"
has_quantization_config = self.model_arch_config.quantization_config is not None has_quantization_config = self.model_arch_config.quantization_config is not None
is_8bit = ( is_8bit = (
self.model_arch_config.quantization_config.get("load_in_8bit", False) self.model_arch_config.quantization_config.get("load_in_8bit", False) # type: ignore[union-attr]
if has_quantization_config if has_quantization_config
else False else False
) )
...@@ -1292,6 +1292,7 @@ class ModelConfig: ...@@ -1292,6 +1292,7 @@ class ModelConfig:
"attn_type_list, or a layer_types in the hf_config, " "attn_type_list, or a layer_types in the hf_config, "
f"cannot determine the num of {block_type} layers" f"cannot determine the num of {block_type} layers"
) )
raise AssertionError(f"Unsupported block type: {block_type}")
def get_mamba_chunk_size(self) -> int | None: def get_mamba_chunk_size(self) -> int | None:
""" """
......
...@@ -108,14 +108,14 @@ class PoolerConfig: ...@@ -108,14 +108,14 @@ class PoolerConfig:
pooling_type, pooling_type,
pooling_type, pooling_type,
) )
self.seq_pooling_type = pooling_type self.seq_pooling_type = pooling_type # type: ignore[assignment]
elif pooling_type in TOK_POOLING_TYPES: elif pooling_type in TOK_POOLING_TYPES:
logger.debug( logger.debug(
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`.", "Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
pooling_type, pooling_type,
pooling_type, pooling_type,
) )
self.tok_pooling_type = pooling_type self.tok_pooling_type = pooling_type # type: ignore[assignment]
else: else:
raise NotImplementedError(pooling_type) raise NotImplementedError(pooling_type)
......
...@@ -173,7 +173,7 @@ class SchedulerConfig: ...@@ -173,7 +173,7 @@ class SchedulerConfig:
logger.warning_once( logger.warning_once(
"Using custom scheduler class %s. This scheduler interface is " "Using custom scheduler class %s. This scheduler interface is "
"not public and compatibility may not be maintained.", "not public and compatibility may not be maintained.",
self.scheduler_cls, self.scheduler_cls, # type: ignore[arg-type]
) )
if not isinstance(self.scheduler_cls, str): if not isinstance(self.scheduler_cls, str):
return cast(type["SchedulerInterface"], self.scheduler_cls) return cast(type["SchedulerInterface"], self.scheduler_cls)
......
...@@ -67,7 +67,7 @@ class SpeculativeConfig: ...@@ -67,7 +67,7 @@ class SpeculativeConfig:
enforce_eager: bool | None = None enforce_eager: bool | None = None
"""Override the default enforce_eager from model_config""" """Override the default enforce_eager from model_config"""
# General speculative decoding control # General speculative decoding control
num_speculative_tokens: int = Field(default=None, gt=0) num_speculative_tokens: int = Field(default=None, gt=0) # type: ignore[assignment]
"""The number of speculative tokens, if provided. It will default to the """The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required.""" number in the draft model config if present, otherwise, it is required."""
model: str | None = None model: str | None = None
...@@ -89,7 +89,7 @@ class SpeculativeConfig: ...@@ -89,7 +89,7 @@ class SpeculativeConfig:
warn users when they mistakenly provide the wrong argument.""" warn users when they mistakenly provide the wrong argument."""
# Draft model configuration # Draft model configuration
quantization: me_quant.QuantizationMethods | None = None quantization: me_quant.QuantizationMethods | str | None = None
"""Quantization method that was used to quantize the draft model weights. """Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method.""" takes effect when using the draft model-based speculative method."""
......
...@@ -11,13 +11,13 @@ import os ...@@ -11,13 +11,13 @@ import os
import pathlib import pathlib
import textwrap import textwrap
from collections.abc import Callable, Mapping, Sequence, Set from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, field, fields, is_dataclass from dataclasses import MISSING, dataclass, field, fields, is_dataclass
from itertools import pairwise from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
import torch import torch
from pydantic import ConfigDict from pydantic import ConfigDict
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass as pydantic_dataclass
from pydantic.fields import Field as PydanticField from pydantic.fields import Field as PydanticField
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import dataclass_transform, runtime_checkable from typing_extensions import dataclass_transform, runtime_checkable
...@@ -58,8 +58,8 @@ def config( ...@@ -58,8 +58,8 @@ def config(
if config is not None: if config is not None:
merged_config.update(config) merged_config.update(config)
def decorator(cls): def decorator(cls: type[ConfigT]) -> type[ConfigT]:
return dataclass(cls, config=merged_config, **kwargs) return pydantic_dataclass(cls, config=merged_config, **kwargs) # type: ignore[return-value]
# Called with arguments: @config(config=...) # Called with arguments: @config(config=...)
if cls is None: if cls is None:
......
...@@ -243,15 +243,15 @@ OPTIMIZATION_LEVEL_TO_CONFIG = { ...@@ -243,15 +243,15 @@ OPTIMIZATION_LEVEL_TO_CONFIG = {
} }
@config(config=ConfigDict(arbitrary_types_allowed=True)) @config(config=ConfigDict(arbitrary_types_allowed=True)) # type: ignore[arg-type,misc]
class VllmConfig: class VllmConfig: # type: ignore[misc]
"""Dataclass which contains all vllm-related configuration. This """Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase. simplifies passing around the distinct configurations in the codebase.
""" """
# TODO: use default_factory once default constructing ModelConfig doesn't # TODO: use default_factory once default constructing ModelConfig doesn't
# try to download a model # try to download a model
model_config: ModelConfig = Field(default=None) model_config: ModelConfig = Field(default=None) # type: ignore[assignment]
"""Model configuration.""" """Model configuration."""
cache_config: CacheConfig = Field(default_factory=CacheConfig) cache_config: CacheConfig = Field(default_factory=CacheConfig)
"""Cache configuration.""" """Cache configuration."""
...@@ -883,7 +883,7 @@ class VllmConfig: ...@@ -883,7 +883,7 @@ class VllmConfig:
tp_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size() hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize element_size = self.model_config.dtype.itemsize # type: ignore[union-attr]
pass_config.sp_min_token_num = get_sequence_parallelism_threshold( pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size hidden_size, tp_size, element_size
) )
...@@ -1061,7 +1061,7 @@ class VllmConfig: ...@@ -1061,7 +1061,7 @@ class VllmConfig:
is_fullgraph = ( is_fullgraph = (
self.compilation_config.use_inductor_graph_partition self.compilation_config.use_inductor_graph_partition
or len(self.compilation_config.splitting_ops) == 0 or len(self.compilation_config.splitting_ops or []) == 0
) )
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph: if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
if "-rms_norm" not in self.compilation_config.custom_ops: if "-rms_norm" not in self.compilation_config.custom_ops:
...@@ -1216,7 +1216,7 @@ class VllmConfig: ...@@ -1216,7 +1216,7 @@ class VllmConfig:
) )
self.compilation_config.debug_dump_path = env_path self.compilation_config.debug_dump_path = env_path
def has_blocked_weights(): def has_blocked_weights(): # type: ignore[no-redef]
if self.quant_config is not None: if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"): if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None return self.quant_config.weight_block_size is not None
...@@ -1474,7 +1474,7 @@ class VllmConfig: ...@@ -1474,7 +1474,7 @@ class VllmConfig:
if max_size is not None: if max_size is not None:
max_token_num = max_size // ( max_token_num = max_size // (
self.model_config.get_hidden_size() self.model_config.get_hidden_size()
* self.model_config.dtype.itemsize * self.model_config.dtype.itemsize # type: ignore[union-attr]
) )
if compile_range_end is not None and max_token_num < compile_range_end: if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_endpoints.append(max_token_num) computed_compile_ranges_endpoints.append(max_token_num)
...@@ -1497,7 +1497,7 @@ class VllmConfig: ...@@ -1497,7 +1497,7 @@ class VllmConfig:
tp_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size() hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize element_size = self.model_config.dtype.itemsize # type: ignore[union-attr]
pass_config.sp_min_token_num = get_sequence_parallelism_threshold( pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size hidden_size, tp_size, element_size
) )
......
...@@ -1924,7 +1924,7 @@ class EngineArgs: ...@@ -1924,7 +1924,7 @@ class EngineArgs:
) )
offload_config = OffloadConfig( offload_config = OffloadConfig(
offload_backend=self.offload_backend, offload_backend=self.offload_backend, # type: ignore[arg-type]
uva=UVAOffloadConfig( uva=UVAOffloadConfig(
cpu_offload_gb=self.cpu_offload_gb, cpu_offload_gb=self.cpu_offload_gb,
cpu_offload_params=self.cpu_offload_params, cpu_offload_params=self.cpu_offload_params,
......
...@@ -72,6 +72,9 @@ class CudagraphDispatcher: ...@@ -72,6 +72,9 @@ class CudagraphDispatcher:
"""Pre-compute the mapping from batch size to padded graph size.""" """Pre-compute the mapping from batch size to padded graph size."""
max_size = self.compilation_config.max_cudagraph_capture_size max_size = self.compilation_config.max_cudagraph_capture_size
capture_sizes = self.compilation_config.cudagraph_capture_sizes capture_sizes = self.compilation_config.cudagraph_capture_sizes
assert max_size is not None, (
"Maximum cudagraph capture size must be set when cudagraphs are enabled."
)
assert capture_sizes is not None, ( assert capture_sizes is not None, (
"Cudagraph capture sizes must be set when cudagraphs are enabled." "Cudagraph capture sizes must be set when cudagraphs are enabled."
) )
...@@ -94,7 +97,7 @@ class CudagraphDispatcher: ...@@ -94,7 +97,7 @@ class CudagraphDispatcher:
): ):
for size in self.compilation_config.compile_sizes: for size in self.compilation_config.compile_sizes:
size = int(size) size = int(size)
if size <= self.compilation_config.max_cudagraph_capture_size: if size <= max_size:
padded = self._bs_to_padded_graph_size[size] padded = self._bs_to_padded_graph_size[size]
if padded != size: if padded != size:
raise ValueError( raise ValueError(
...@@ -265,11 +268,13 @@ class CudagraphDispatcher: ...@@ -265,11 +268,13 @@ class CudagraphDispatcher:
f"No allowed cudagraph modes: valid_modes={valid_modes}, " f"No allowed cudagraph modes: valid_modes={valid_modes}, "
f"invalid_modes={invalid_modes}" f"invalid_modes={invalid_modes}"
) )
max_size = self.compilation_config.max_cudagraph_capture_size
if ( if (
not self.keys_initialized not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size or max_size is None
or num_tokens > max_size
or allowed_modes <= {CUDAGraphMode.NONE} or allowed_modes <= {CUDAGraphMode.NONE}
): ):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment