Unverified Commit 61aedb5f authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Move`VllmConfig` from `config/__init__.py` to `config/vllm.py` (#25271)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent d3bd1711
...@@ -20,8 +20,7 @@ from vllm.forward_context import ForwardContext, get_forward_context ...@@ -20,8 +20,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
......
...@@ -9,7 +9,8 @@ from vllm import envs ...@@ -9,7 +9,8 @@ from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata) AttentionMetadata)
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches, CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend) subclass_attention_backend)
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for vLLM config dataclasses."""
import ast import ast
import inspect import inspect
import textwrap import textwrap
from dataclasses import MISSING, Field, field, fields, is_dataclass from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from typing import TYPE_CHECKING, Any, TypeVar from typing import TYPE_CHECKING, Any, Protocol, TypeVar
import regex as re import regex as re
from typing_extensions import runtime_checkable
if TYPE_CHECKING: if TYPE_CHECKING:
from _typeshed import DataclassInstance from _typeshed import DataclassInstance
ConfigType = type[DataclassInstance]
else: else:
ConfigType = type DataclassInstance = Any
ConfigType = type[DataclassInstance]
ConfigT = TypeVar("ConfigT", bound=ConfigType) ConfigT = TypeVar("ConfigT", bound=ConfigType)
...@@ -143,3 +143,33 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: ...@@ -143,3 +143,33 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
def is_init_field(cls: ConfigType, name: str) -> bool: def is_init_field(cls: ConfigType, name: str) -> bool:
return next(f for f in fields(cls) if f.name == name).init return next(f for f in fields(cls) if f.name == name).init
@runtime_checkable
class SupportsHash(Protocol):
def compute_hash(self) -> str:
...
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> dict[str, str]:
...
def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
processed_overrides = {}
for field_name, value in overrides.items():
assert hasattr(
config, field_name), f"{type(config)} has no field `{field_name}`"
current_value = getattr(config, field_name)
if is_dataclass(current_value) and not is_dataclass(value):
assert isinstance(value, dict), (
f"Overrides to {type(config)}.{field_name} must be a dict"
f" or {type(current_value)}, but got {type(value)}")
value = update_config(
current_value, # type: ignore[type-var]
value)
processed_overrides[field_name] = value
return replace(config, **processed_overrides)
This diff is collapsed.
...@@ -29,8 +29,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -29,8 +29,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
......
...@@ -9,9 +9,8 @@ import torch ...@@ -9,9 +9,8 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import (QuantizationConfig,
from vllm.model_executor.layers.quantization.base_config import ( QuantizationMethods)
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
......
...@@ -7,9 +7,8 @@ from packaging import version ...@@ -7,9 +7,8 @@ from packaging import version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import (QuantizationConfig,
from vllm.model_executor.layers.quantization.base_config import ( QuantizationMethods)
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS, BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION) BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
......
...@@ -13,9 +13,8 @@ from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, ...@@ -13,9 +13,8 @@ from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import (QuantizationConfig,
from vllm.model_executor.layers.quantization.base_config import ( QuantizationMethods)
QuantizationConfig)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
......
...@@ -9,9 +9,8 @@ import torch.nn.functional as F ...@@ -9,9 +9,8 @@ import torch.nn.functional as F
from packaging import version from packaging import version
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import (QuantizationConfig,
from vllm.model_executor.layers.quantization.base_config import ( QuantizationMethods)
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import enum import enum
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
...@@ -13,7 +13,6 @@ from torch.nn.parameter import Parameter ...@@ -13,7 +13,6 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.gptq_utils import ( from vllm.model_executor.layers.quantization.utils.gptq_utils import (
...@@ -26,6 +25,11 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ...@@ -26,6 +25,11 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
from vllm.transformers_utils.config import get_safetensors_params_metadata from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of from vllm.utils import is_list_of
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
else:
QuantizationMethods = str
class GPTQConfig(QuantizationConfig): class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ. """Config class for GPTQ.
......
...@@ -9,9 +9,8 @@ from torch.nn.parameter import Parameter ...@@ -9,9 +9,8 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import (QuantizationConfig,
from vllm.model_executor.layers.quantization.base_config import ( QuantizationMethods)
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
BitBLASLinearKernel, MPLinearLayerConfig) BitBLASLinearKernel, MPLinearLayerConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
......
...@@ -43,7 +43,7 @@ logger = init_logger(__name__) ...@@ -43,7 +43,7 @@ logger = init_logger(__name__)
def get_moe_quant_method( def get_moe_quant_method(
config: QuantizationConfig, config: "GPTQMarlinConfig",
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
moe_method_cls: type, moe_method_cls: type,
......
...@@ -9,9 +9,8 @@ from torch.nn.parameter import Parameter ...@@ -9,9 +9,8 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import (QuantizationConfig,
from vllm.model_executor.layers.quantization.base_config import ( QuantizationMethods)
QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
......
...@@ -14,11 +14,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, ...@@ -14,11 +14,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq) is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
Fp8LinearMethod) Fp8LinearMethod)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
......
...@@ -7,8 +7,7 @@ import torch ...@@ -7,8 +7,7 @@ import torch
from packaging import version from packaging import version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES, BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
......
...@@ -8,9 +8,8 @@ from torch.nn import Module ...@@ -8,9 +8,8 @@ from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import (QuantizationConfig,
from vllm.model_executor.layers.quantization.base_config import ( QuantizationMethods)
QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter from vllm.model_executor.parameter import ModelWeightParameter
ACTIVATION_SCHEMES = ["none", "dynamic"] ACTIVATION_SCHEMES = ["none", "dynamic"]
......
...@@ -4,21 +4,27 @@ from collections.abc import Mapping ...@@ -4,21 +4,27 @@ from collections.abc import Mapping
from copy import deepcopy from copy import deepcopy
from fractions import Fraction from fractions import Fraction
from types import MappingProxyType from types import MappingProxyType
from typing import Optional, Union from typing import TYPE_CHECKING, Optional, Union
import regex as re import regex as re
import torch import torch
from vllm.config import QuantizationConfig
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, UnquantizedEmbeddingMethod) ParallelLMHead, UnquantizedEmbeddingMethod)
if TYPE_CHECKING:
from ..gptq import GPTQConfig
from ..gptq_marlin import GPTQMarlinConfig
else:
GPTQConfig = object
GPTQMarlinConfig = object
# Match dynamic rules with module name (prefix) and override quantize # Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule # config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str): def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str):
weight_bits = get_dynamic_override(config, prefix, "bits", weight_bits = get_dynamic_override(config, prefix, "bits",
config.weight_bits) config.weight_bits)
if isinstance(weight_bits, int): if isinstance(weight_bits, int):
...@@ -34,6 +40,7 @@ def override_config(config: QuantizationConfig, prefix: str): ...@@ -34,6 +40,7 @@ def override_config(config: QuantizationConfig, prefix: str):
config.pack_factor = Fraction(32, config.weight_bits) # packed into int32 config.pack_factor = Fraction(32, config.weight_bits) # packed into int32
if config.get_name() == "gptq_marlin": if config.get_name() == "gptq_marlin":
assert isinstance(config, GPTQMarlinConfig)
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
if isinstance(is_sym, bool): if isinstance(is_sym, bool):
config.is_sym = is_sym config.is_sym = is_sym
...@@ -45,6 +52,7 @@ def override_config(config: QuantizationConfig, prefix: str): ...@@ -45,6 +52,7 @@ def override_config(config: QuantizationConfig, prefix: str):
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.quant_type = config.TYPE_MAP[(config.weight_bits,
config.is_sym)] config.is_sym)]
elif config.get_name() == "gptq": elif config.get_name() == "gptq":
assert isinstance(config, GPTQConfig)
if config.weight_bits not in [2, 3, 4, 8]: if config.weight_bits not in [2, 3, 4, 8]:
raise ValueError( raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is " "Currently, only 2/3/4/8-bit weight quantization is "
...@@ -52,7 +60,7 @@ def override_config(config: QuantizationConfig, prefix: str): ...@@ -52,7 +60,7 @@ def override_config(config: QuantizationConfig, prefix: str):
def get_dynamic_override( def get_dynamic_override(
config: QuantizationConfig, config: Union[GPTQConfig, GPTQMarlinConfig],
layer_name: str, layer_name: str,
key: Optional[str] = None, key: Optional[str] = None,
default_value: Union[int, bool, default_value: Union[int, bool,
...@@ -116,7 +124,7 @@ def is_layer_gptq_quantized( ...@@ -116,7 +124,7 @@ def is_layer_gptq_quantized(
def get_linear_quant_method( def get_linear_quant_method(
config: QuantizationConfig, config: Union[GPTQConfig, GPTQMarlinConfig],
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
linear_method_cls: type, linear_method_cls: type,
......
...@@ -17,8 +17,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -17,8 +17,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.ovis import AIMv2Config from vllm.transformers_utils.configs.ovis import AIMv2Config
......
...@@ -9,13 +9,14 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature ...@@ -9,13 +9,14 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor from transformers.models.aria.processing_aria import AriaProcessor
from vllm.config import QuantizationConfig, VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment