Unverified Commit d215d1ef authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

[Mypy] Better fixes for the `mypy` issues in `vllm/config` (#37902)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 34d317dc
...@@ -8,7 +8,6 @@ import os ...@@ -8,7 +8,6 @@ import os
import random import random
import time import time
import warnings import warnings
from dataclasses import fields
from typing import Any from typing import Any
import torch import torch
...@@ -53,7 +52,7 @@ def run_vllm( ...@@ -53,7 +52,7 @@ def run_vllm(
) -> tuple[float, list[RequestOutput] | None]: ) -> tuple[float, list[RequestOutput] | None]:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)}) llm = LLM.from_engine_args(engine_args)
assert all( assert all(
llm.llm_engine.model_config.max_model_len llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len) >= (request.prompt_len + request.expected_output_len)
...@@ -141,7 +140,7 @@ def run_vllm_chat( ...@@ -141,7 +140,7 @@ def run_vllm_chat(
""" """
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)}) llm = LLM.from_engine_args(engine_args)
assert all( assert all(
llm.llm_engine.model_config.max_model_len llm.llm_engine.model_config.max_model_len
......
...@@ -116,29 +116,29 @@ class PassConfig: ...@@ -116,29 +116,29 @@ class PassConfig:
""" """
# New flags # New flags
fuse_norm_quant: bool | None = Field(default=None) fuse_norm_quant: bool = None # type: ignore[assignment]
"""Fuse the custom RMSNorm + quant ops.""" """Fuse the custom RMSNorm + quant ops."""
fuse_act_quant: bool | None = Field(default=None) fuse_act_quant: bool = None # type: ignore[assignment]
"""Fuse the custom SiluMul + quant ops.""" """Fuse the custom SiluMul + quant ops."""
fuse_attn_quant: bool | None = Field(default=None) fuse_attn_quant: bool = None # type: ignore[assignment]
"""Fuse the custom attention + quant ops.""" """Fuse the custom attention + quant ops."""
eliminate_noops: bool = Field(default=True) eliminate_noops: bool = Field(default=True)
"""Eliminate no-op ops.""" """Eliminate no-op ops."""
enable_sp: bool | None = Field(default=None) enable_sp: bool = None # type: ignore[assignment]
"""Enable sequence parallelism. Requires TP>1. Automatically disabled """Enable sequence parallelism. Requires TP>1. Automatically disabled
if the model's hidden_size is too small for SP to be beneficial if the model's hidden_size is too small for SP to be beneficial
(threshold is device-capability dependent).""" (threshold is device-capability dependent)."""
fuse_gemm_comms: bool | None = Field(default=None) fuse_gemm_comms: bool = None # type: ignore[assignment]
"""Enable async TP.""" """Enable async TP."""
fuse_allreduce_rms: bool | None = Field(default=None) fuse_allreduce_rms: bool = None # type: ignore[assignment]
"""Enable flashinfer allreduce fusion.""" """Enable flashinfer allreduce fusion."""
enable_qk_norm_rope_fusion: bool = False enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass.""" """Enable fused Q/K RMSNorm + RoPE pass."""
# ROCm/AITER specific fusions # ROCm/AITER specific fusions
fuse_act_padding: bool | None = Field(default=None) fuse_act_padding: bool = None # type: ignore[assignment]
"""Fuse the custom RMSNorm + padding ops.""" """Fuse the custom RMSNorm + padding ops."""
fuse_rope_kvcache: bool | None = Field(default=None) fuse_rope_kvcache: bool = None # type: ignore[assignment]
"""Fuse the QK rope + KV cache ops.""" """Fuse the QK rope + KV cache ops."""
rope_kvcache_fusion_max_token_num: int = 256 rope_kvcache_fusion_max_token_num: int = 256
...@@ -405,7 +405,7 @@ class CompilationConfig: ...@@ -405,7 +405,7 @@ class CompilationConfig:
""" """
# Top-level Compilation control # Top-level Compilation control
mode: CompilationMode = Field(default=None) # type: ignore[assignment] mode: CompilationMode = None # type: ignore[assignment]
"""The compilation approach used for torch.compile-based compilation of the """The compilation approach used for torch.compile-based compilation of the
model. model.
...@@ -545,7 +545,7 @@ class CompilationConfig: ...@@ -545,7 +545,7 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation # CudaGraph compilation
cudagraph_mode: CUDAGraphMode = Field(default=None) # type: ignore[assignment] cudagraph_mode: CUDAGraphMode = None # type: ignore[assignment]
""" """
The mode of the cudagraph: The mode of the cudagraph:
...@@ -586,7 +586,7 @@ class CompilationConfig: ...@@ -586,7 +586,7 @@ class CompilationConfig:
It means the first several runs will be treated as warmup runs. It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.""" cudagraph will be used for subsequent runs."""
cudagraph_capture_sizes: list[int] | None = None cudagraph_capture_sizes: list[int] = None # type: ignore[assignment]
"""Sizes to capture cudagraph. """Sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config. - None (default): capture sizes are inferred from vllm config.
- list[int]: capture sizes are specified as given.""" - list[int]: capture sizes are specified as given."""
...@@ -607,7 +607,7 @@ class CompilationConfig: ...@@ -607,7 +607,7 @@ class CompilationConfig:
When `enable_lora` is False, this option has no effect. When `enable_lora` is False, this option has no effect.
""" """
use_inductor_graph_partition: bool = Field(default=None) # type: ignore[assignment] use_inductor_graph_partition: bool = None # type: ignore[assignment]
"""Use inductor graph partition to split the graph at cudagraph_unsafe ops. """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
This partition happens at inductor codegen time after all passes and fusions This partition happens at inductor codegen time after all passes and fusions
are finished. It generates a single `call` function which wraps are finished. It generates a single `call` function which wraps
...@@ -630,7 +630,7 @@ class CompilationConfig: ...@@ -630,7 +630,7 @@ class CompilationConfig:
pass_config: PassConfig = field(default_factory=PassConfig) pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details""" """Custom inductor passes, see PassConfig for more details"""
max_cudagraph_capture_size: int | None = field(default=None) max_cudagraph_capture_size: int = None # type: ignore[assignment]
"""The maximum cudagraph capture size. """The maximum cudagraph capture size.
If cudagraph_capture_sizes is specified, this will be set to the largest If cudagraph_capture_sizes is specified, this will be set to the largest
...@@ -750,7 +750,7 @@ class CompilationConfig: ...@@ -750,7 +750,7 @@ class CompilationConfig:
return hash_factors(factors) return hash_factors(factors)
def __repr__(self) -> str: def __repr__(self) -> str:
exclude = { exclude: dict[str, bool | dict[str, bool]] = {
"static_forward_context": True, "static_forward_context": True,
"enabled_custom_ops": True, "enabled_custom_ops": True,
"disabled_custom_ops": True, "disabled_custom_ops": True,
...@@ -770,9 +770,7 @@ class CompilationConfig: ...@@ -770,9 +770,7 @@ class CompilationConfig:
exclude["pass_config"] = pass_config_exclude exclude["pass_config"] = pass_config_exclude
config = TypeAdapter(CompilationConfig).dump_python( config = TypeAdapter(CompilationConfig).dump_python(
self, self, exclude=exclude, exclude_unset=True
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True,
) )
return str(config) return str(config)
...@@ -1023,7 +1021,6 @@ class CompilationConfig: ...@@ -1023,7 +1021,6 @@ class CompilationConfig:
"Unrecognized size type in compile_sizes, " "Unrecognized size type in compile_sizes, "
f"expect 'cudagraph_capture_sizes', got {x}" f"expect 'cudagraph_capture_sizes', got {x}"
) )
assert self.cudagraph_capture_sizes is not None
computed_compile_sizes.extend(self.cudagraph_capture_sizes) computed_compile_sizes.extend(self.cudagraph_capture_sizes)
else: else:
assert isinstance(x, int) assert isinstance(x, int)
...@@ -1031,7 +1028,6 @@ class CompilationConfig: ...@@ -1031,7 +1028,6 @@ class CompilationConfig:
self.compile_sizes = computed_compile_sizes # type: ignore self.compile_sizes = computed_compile_sizes # type: ignore
# make sure the sizes are in ascending order # make sure the sizes are in ascending order
assert self.cudagraph_capture_sizes is not None
self.cudagraph_capture_sizes.sort() self.cudagraph_capture_sizes.sort()
if self.cudagraph_capture_sizes: if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
...@@ -1123,7 +1119,6 @@ class CompilationConfig: ...@@ -1123,7 +1119,6 @@ class CompilationConfig:
def set_splitting_ops_for_attn_fusion(self): def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.fuse_attn_quant assert self.pass_config.fuse_attn_quant
assert self.cudagraph_mode is not None
if self.splitting_ops is None: if self.splitting_ops is None:
self.splitting_ops = [] self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs(): if self.cudagraph_mode.has_piecewise_cudagraphs():
......
...@@ -13,8 +13,8 @@ from vllm.utils.hashing import safe_hash ...@@ -13,8 +13,8 @@ from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config(config=ConfigDict(arbitrary_types_allowed=True)) # type: ignore[arg-type,misc] @config(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig: # type: ignore[misc] class DeviceConfig:
"""Configuration for the device to use for vLLM execution.""" """Configuration for the device to use for vLLM execution."""
device: SkipValidation[Device | torch.device | None] = "auto" device: SkipValidation[Device | torch.device | None] = "auto"
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Literal from typing import Any, Literal
from pydantic import Field, field_validator from pydantic import field_validator
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
...@@ -26,7 +26,7 @@ MoEBackend = Literal[ ...@@ -26,7 +26,7 @@ MoEBackend = Literal[
class KernelConfig: class KernelConfig:
"""Configuration for kernel selection and warmup behavior.""" """Configuration for kernel selection and warmup behavior."""
enable_flashinfer_autotune: bool | None = Field(default=None) enable_flashinfer_autotune: bool = None # type: ignore[assignment]
"""If True, run FlashInfer autotuning during kernel warmup.""" """If True, run FlashInfer autotuning during kernel warmup."""
moe_backend: MoEBackend = "auto" moe_backend: MoEBackend = "auto"
......
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
from typing import Literal from typing import Literal
from pydantic import Field
from vllm.config.utils import config from vllm.config.utils import config
...@@ -18,7 +16,7 @@ class KVEventsConfig: ...@@ -18,7 +16,7 @@ class KVEventsConfig:
Events can be published externally by zmq using the event publisher config. Events can be published externally by zmq using the event publisher config.
""" """
publisher: Literal["null", "zmq"] | None = Field(default=None) publisher: Literal["null", "zmq"] = None # type: ignore[assignment]
"""The publisher to use for publishing kv events. Can be "null", "zmq". """The publisher to use for publishing kv events. Can be "null", "zmq".
""" """
......
...@@ -25,8 +25,8 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512] ...@@ -25,8 +25,8 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
LoRAExtraVocabSize = Literal[256, 512] LoRAExtraVocabSize = Literal[256, 512]
@config(config=ConfigDict(arbitrary_types_allowed=True)) # type: ignore[arg-type,misc] @config(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig: # type: ignore[misc] class LoRAConfig:
"""Configuration for LoRA.""" """Configuration for LoRA."""
max_lora_rank: MaxLoRARanks = 16 max_lora_rank: MaxLoRARanks = 16
......
...@@ -102,8 +102,8 @@ AttnTypeStr = Literal[ ...@@ -102,8 +102,8 @@ AttnTypeStr = Literal[
] ]
@config(config=ConfigDict(arbitrary_types_allowed=True)) # type: ignore[arg-type,misc] @config(config=ConfigDict(arbitrary_types_allowed=True))
class ModelConfig: # type: ignore[misc] class ModelConfig:
"""Configuration for the model.""" """Configuration for the model."""
model: str = "Qwen/Qwen3-0.6B" model: str = "Qwen/Qwen3-0.6B"
...@@ -121,7 +121,7 @@ class ModelConfig: # type: ignore[misc] ...@@ -121,7 +121,7 @@ class ModelConfig: # type: ignore[misc]
"""Convert the model using adapters defined in """Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to [vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks.""" adapt a text generation model to be used for pooling tasks."""
tokenizer: str = Field(default=None) # type: ignore[assignment] tokenizer: str = None # type: ignore[assignment]
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model """Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used.""" name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto" tokenizer_mode: TokenizerMode | str = "auto"
...@@ -583,7 +583,7 @@ class ModelConfig: # type: ignore[misc] ...@@ -583,7 +583,7 @@ class ModelConfig: # type: ignore[misc]
self.dtype, self.dtype,
is_pooling_model=self.runner_type == "pooling", is_pooling_model=self.runner_type == "pooling",
revision=self.revision, revision=self.revision,
config_format=self.config_format, # type: ignore[arg-type] config_format=self.config_format,
) )
self.original_max_model_len = self.max_model_len self.original_max_model_len = self.max_model_len
...@@ -733,7 +733,7 @@ class ModelConfig: # type: ignore[misc] ...@@ -733,7 +733,7 @@ class ModelConfig: # type: ignore[misc]
@property @property
def architectures(self) -> list[str]: def architectures(self) -> list[str]:
return self.model_arch_config.architectures # type: ignore[return-value] return self.model_arch_config.architectures
@property @property
def architecture(self) -> str: def architecture(self) -> str:
...@@ -1944,7 +1944,7 @@ def _get_and_verify_dtype( ...@@ -1944,7 +1944,7 @@ def _get_and_verify_dtype(
*, *,
is_pooling_model: bool, is_pooling_model: bool,
revision: str | None = None, revision: str | None = None,
config_format: ConfigFormat = "hf", config_format: str | ConfigFormat = "hf",
) -> torch.dtype: ) -> torch.dtype:
config_dtype = ModelArchConfigConvertorBase.get_torch_dtype( config_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
config, model_id, revision=revision, config_format=config_format config, model_id, revision=revision, config_format=config_format
......
...@@ -16,7 +16,7 @@ class ModelArchitectureConfig: ...@@ -16,7 +16,7 @@ class ModelArchitectureConfig:
Configuration for model architecture that required by vLLM runtime Configuration for model architecture that required by vLLM runtime
""" """
architectures: list[str] | None architectures: list[str]
"""List of model architecture class names (e.g., ['LlamaForCausalLM']). """List of model architecture class names (e.g., ['LlamaForCausalLM']).
It can be None upon calling `vllm_config.with_hf_config(config.text_config)`""" It can be None upon calling `vllm_config.with_hf_config(config.text_config)`"""
......
...@@ -194,7 +194,7 @@ class ParallelConfig: ...@@ -194,7 +194,7 @@ class ParallelConfig:
threshold, microbatching will be used. Otherwise, the request will be threshold, microbatching will be used. Otherwise, the request will be
processed in a single batch.""" processed in a single batch."""
disable_nccl_for_dp_synchronization: bool | None = Field(default=None) disable_nccl_for_dp_synchronization: bool | None = None
"""Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
to use Gloo instead of NCCL for its all reduce. to use Gloo instead of NCCL for its all reduce.
......
...@@ -52,7 +52,7 @@ class SchedulerConfig: ...@@ -52,7 +52,7 @@ class SchedulerConfig:
In real usage, this should be set in `EngineArgs.create_engine_config`. In real usage, this should be set in `EngineArgs.create_engine_config`.
""" """
max_num_scheduled_tokens: int | None = Field(default=None) max_num_scheduled_tokens: int | None = None
"""Maximum number of tokens that the scheduler may issue in a single iteration. """Maximum number of tokens that the scheduler may issue in a single iteration.
This is usually equal to max_num_batched_tokens, but can be smaller in cases This is usually equal to max_num_batched_tokens, but can be smaller in cases
...@@ -122,7 +122,7 @@ class SchedulerConfig: ...@@ -122,7 +122,7 @@ class SchedulerConfig:
# scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler" # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
# (default) or "mod.custom_class". # (default) or "mod.custom_class".
scheduler_cls: str | type[object] | None = Field(default=None) scheduler_cls: str | type[object] | None = None
"""The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
the default scheduler. Can be a class directly or the path to a class of the default scheduler. Can be a class directly or the path to a class of
form "mod.custom_class".""" form "mod.custom_class"."""
...@@ -141,7 +141,7 @@ class SchedulerConfig: ...@@ -141,7 +141,7 @@ class SchedulerConfig:
checking the first chunk. Prevents over-admission and KV cache thrashing checking the first chunk. Prevents over-admission and KV cache thrashing
with chunked prefill.""" with chunked prefill."""
async_scheduling: bool | None = Field(default=None) async_scheduling: bool | None = None
"""If set to False, disable async scheduling. Async scheduling helps to """If set to False, disable async scheduling. Async scheduling helps to
avoid gaps in GPU utilization, leading to better latency and throughput. avoid gaps in GPU utilization, leading to better latency and throughput.
""" """
......
...@@ -11,13 +11,13 @@ import os ...@@ -11,13 +11,13 @@ import os
import pathlib import pathlib
import textwrap import textwrap
from collections.abc import Callable, Mapping, Sequence, Set from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, dataclass, field, fields, is_dataclass from dataclasses import MISSING, field, fields, is_dataclass
from itertools import pairwise from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast, overload
import torch import torch
from pydantic import ConfigDict from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pydantic_dataclass from pydantic.dataclasses import dataclass
from pydantic.fields import Field as PydanticField from pydantic.fields import Field as PydanticField
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import dataclass_transform, runtime_checkable from typing_extensions import dataclass_transform, runtime_checkable
...@@ -36,6 +36,16 @@ ConfigType = type[DataclassInstance] ...@@ -36,6 +36,16 @@ ConfigType = type[DataclassInstance]
ConfigT = TypeVar("ConfigT", bound=DataclassInstance) ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
@overload
def config(cls: type[ConfigT]) -> type[ConfigT]: ...
@overload
def config(
*, config: ConfigDict | None = None, **kwargs: Any
) -> Callable[[type[ConfigT]], type[ConfigT]]: ...
@dataclass_transform(field_specifiers=(PydanticField,)) @dataclass_transform(field_specifiers=(PydanticField,))
def config( def config(
cls: type[ConfigT] | None = None, cls: type[ConfigT] | None = None,
...@@ -59,7 +69,7 @@ def config( ...@@ -59,7 +69,7 @@ def config(
merged_config.update(config) merged_config.update(config)
def decorator(cls: type[ConfigT]) -> type[ConfigT]: def decorator(cls: type[ConfigT]) -> type[ConfigT]:
return pydantic_dataclass(cls, config=merged_config, **kwargs) # type: ignore[return-value] return dataclass(cls, config=merged_config, **kwargs) # type: ignore[return-value]
# Called with arguments: @config(config=...) # Called with arguments: @config(config=...)
if cls is None: if cls is None:
......
...@@ -246,15 +246,15 @@ OPTIMIZATION_LEVEL_TO_CONFIG = { ...@@ -246,15 +246,15 @@ OPTIMIZATION_LEVEL_TO_CONFIG = {
} }
@config(config=ConfigDict(arbitrary_types_allowed=True)) # type: ignore[arg-type,misc] @config(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig: # type: ignore[misc] class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This """Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase. simplifies passing around the distinct configurations in the codebase.
""" """
# TODO: use default_factory once default constructing ModelConfig doesn't # TODO: use default_factory once default constructing ModelConfig doesn't
# try to download a model # try to download a model
model_config: ModelConfig = Field(default=None) # type: ignore[assignment] model_config: ModelConfig = None # type: ignore[assignment]
"""Model configuration.""" """Model configuration."""
cache_config: CacheConfig = Field(default_factory=CacheConfig) cache_config: CacheConfig = Field(default_factory=CacheConfig)
"""Cache configuration.""" """Cache configuration."""
...@@ -912,7 +912,8 @@ class VllmConfig: # type: ignore[misc] ...@@ -912,7 +912,8 @@ class VllmConfig: # type: ignore[misc]
tp_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size() hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize # type: ignore[union-attr] assert isinstance(self.model_config.dtype, torch.dtype)
element_size = self.model_config.dtype.itemsize
pass_config.sp_min_token_num = get_sequence_parallelism_threshold( pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size hidden_size, tp_size, element_size
) )
...@@ -1246,14 +1247,6 @@ class VllmConfig: # type: ignore[misc] ...@@ -1246,14 +1247,6 @@ class VllmConfig: # type: ignore[misc]
) )
self.compilation_config.debug_dump_path = env_path self.compilation_config.debug_dump_path = env_path
def has_blocked_weights(): # type: ignore[no-redef]
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None
elif hasattr(self.quant_config, "has_blocked_weights"):
return self.quant_config.has_blocked_weights()
return False
# Enable quant_fp8 CUDA ops (TODO disable in follow up) # Enable quant_fp8 CUDA ops (TODO disable in follow up)
# On H100 the CUDA kernel is faster than # On H100 the CUDA kernel is faster than
# native implementation # native implementation
...@@ -1502,9 +1495,10 @@ class VllmConfig: # type: ignore[misc] ...@@ -1502,9 +1495,10 @@ class VllmConfig: # type: ignore[misc]
tp_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
if max_size is not None: if max_size is not None:
assert isinstance(self.model_config.dtype, torch.dtype)
max_token_num = max_size // ( max_token_num = max_size // (
self.model_config.get_hidden_size() self.model_config.get_hidden_size()
* self.model_config.dtype.itemsize # type: ignore[union-attr] * self.model_config.dtype.itemsize
) )
if compile_range_end is not None and max_token_num < compile_range_end: if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_endpoints.append(max_token_num) computed_compile_ranges_endpoints.append(max_token_num)
...@@ -1527,7 +1521,8 @@ class VllmConfig: # type: ignore[misc] ...@@ -1527,7 +1521,8 @@ class VllmConfig: # type: ignore[misc]
tp_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size() hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize # type: ignore[union-attr] assert isinstance(self.model_config.dtype, torch.dtype)
element_size = self.model_config.dtype.itemsize
pass_config.sp_min_token_num = get_sequence_parallelism_threshold( pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size hidden_size, tp_size, element_size
) )
......
...@@ -1935,7 +1935,7 @@ class EngineArgs: ...@@ -1935,7 +1935,7 @@ class EngineArgs:
) )
offload_config = OffloadConfig( offload_config = OffloadConfig(
offload_backend=self.offload_backend, # type: ignore[arg-type] offload_backend=self.offload_backend,
uva=UVAOffloadConfig( uva=UVAOffloadConfig(
cpu_offload_gb=self.cpu_offload_gb, cpu_offload_gb=self.cpu_offload_gb,
cpu_offload_params=self.cpu_offload_params, cpu_offload_params=self.cpu_offload_params,
......
...@@ -409,6 +409,11 @@ class LLM: ...@@ -409,6 +409,11 @@ class LLM:
# Cache for __repr__ to avoid repeated collective_rpc calls # Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None self._cached_repr: str | None = None
@classmethod
def from_engine_args(cls, engine_args: EngineArgs) -> "LLM":
"""Create an LLM instance from EngineArgs."""
return cls(**vars(engine_args))
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer() return self.llm_engine.get_tokenizer()
......
...@@ -28,7 +28,10 @@ class ModelArchConfigConvertorBase: ...@@ -28,7 +28,10 @@ class ModelArchConfigConvertorBase:
self.hf_text_config = hf_text_config self.hf_text_config = hf_text_config
def get_architectures(self) -> list[str]: def get_architectures(self) -> list[str]:
return getattr(self.hf_config, "architectures", []) # Sometimes we get here from `vllm_config.with_hf_config(text_config)` where
# `text_config` is a sub-config from a multi-modal model. If this is the case,
# the sub-config will not have `architectures` and it will explicitly be `None`
return getattr(self.hf_config, "architectures", None) or []
def get_num_hidden_layers(self) -> int: def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_hidden_layers", 0) return getattr(self.hf_text_config, "num_hidden_layers", 0)
...@@ -128,7 +131,7 @@ class ModelArchConfigConvertorBase: ...@@ -128,7 +131,7 @@ class ModelArchConfigConvertorBase:
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
model_id: str, model_id: str,
revision: str | None, revision: str | None,
config_format: ConfigFormat, config_format: str | ConfigFormat,
): ):
# NOTE: getattr(config, "dtype", torch.float32) is not correct # NOTE: getattr(config, "dtype", torch.float32) is not correct
# because config.dtype can be None. # because config.dtype can be None.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment