Unverified Commit 61aedb5f authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Move`VllmConfig` from `config/__init__.py` to `config/vllm.py` (#25271)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent d3bd1711
......@@ -20,8 +20,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
......
......@@ -9,7 +9,8 @@ from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend)
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for vLLM config dataclasses."""
import ast
import inspect
import textwrap
from dataclasses import MISSING, Field, field, fields, is_dataclass
from typing import TYPE_CHECKING, Any, TypeVar
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
import regex as re
from typing_extensions import runtime_checkable
if TYPE_CHECKING:
from _typeshed import DataclassInstance
ConfigType = type[DataclassInstance]
else:
ConfigType = type
DataclassInstance = Any
ConfigType = type[DataclassInstance]
ConfigT = TypeVar("ConfigT", bound=ConfigType)
......@@ -143,3 +143,33 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
def is_init_field(cls: ConfigType, name: str) -> bool:
return next(f for f in fields(cls) if f.name == name).init
@runtime_checkable
class SupportsHash(Protocol):
def compute_hash(self) -> str:
...
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> dict[str, str]:
...
def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
processed_overrides = {}
for field_name, value in overrides.items():
assert hasattr(
config, field_name), f"{type(config)} has no field `{field_name}`"
current_value = getattr(config, field_name)
if is_dataclass(current_value) and not is_dataclass(value):
assert isinstance(value, dict), (
f"Overrides to {type(config)}.{field_name} must be a dict"
f" or {type(current_value)}, but got {type(value)}")
value = update_config(
current_value, # type: ignore[type-var]
value)
processed_overrides[field_name] = value
return replace(config, **processed_overrides)
This diff is collapsed.
......@@ -29,8 +29,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
......
......@@ -9,9 +9,8 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
......
......@@ -7,9 +7,8 @@ from packaging import version
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
......
......@@ -13,9 +13,8 @@ from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
......
......@@ -9,9 +9,8 @@ import torch.nn.functional as F
from packaging import version
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.model_executor.utils import set_weight_attrs
......
......@@ -4,7 +4,7 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
......@@ -13,7 +13,6 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
......@@ -26,6 +25,11 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
else:
QuantizationMethods = str
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
......
......@@ -9,9 +9,8 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
BitBLASLinearKernel, MPLinearLayerConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
......
......@@ -43,7 +43,7 @@ logger = init_logger(__name__)
def get_moe_quant_method(
config: QuantizationConfig,
config: "GPTQMarlinConfig",
layer: torch.nn.Module,
prefix: str,
moe_method_cls: type,
......
......@@ -9,9 +9,8 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
......
......@@ -14,11 +14,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
Fp8LinearMethod)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
......
......@@ -7,8 +7,7 @@ import torch
from packaging import version
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
......
......@@ -8,9 +8,8 @@ from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import (QuantizationConfig,
QuantizationMethods)
from vllm.model_executor.parameter import ModelWeightParameter
ACTIVATION_SCHEMES = ["none", "dynamic"]
......
......@@ -4,21 +4,27 @@ from collections.abc import Mapping
from copy import deepcopy
from fractions import Fraction
from types import MappingProxyType
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
import regex as re
import torch
from vllm.config import QuantizationConfig
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, UnquantizedEmbeddingMethod)
if TYPE_CHECKING:
from ..gptq import GPTQConfig
from ..gptq_marlin import GPTQMarlinConfig
else:
GPTQConfig = object
GPTQMarlinConfig = object
# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str):
weight_bits = get_dynamic_override(config, prefix, "bits",
config.weight_bits)
if isinstance(weight_bits, int):
......@@ -34,6 +40,7 @@ def override_config(config: QuantizationConfig, prefix: str):
config.pack_factor = Fraction(32, config.weight_bits) # packed into int32
if config.get_name() == "gptq_marlin":
assert isinstance(config, GPTQMarlinConfig)
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
if isinstance(is_sym, bool):
config.is_sym = is_sym
......@@ -45,6 +52,7 @@ def override_config(config: QuantizationConfig, prefix: str):
config.quant_type = config.TYPE_MAP[(config.weight_bits,
config.is_sym)]
elif config.get_name() == "gptq":
assert isinstance(config, GPTQConfig)
if config.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
......@@ -52,7 +60,7 @@ def override_config(config: QuantizationConfig, prefix: str):
def get_dynamic_override(
config: QuantizationConfig,
config: Union[GPTQConfig, GPTQMarlinConfig],
layer_name: str,
key: Optional[str] = None,
default_value: Union[int, bool,
......@@ -116,7 +124,7 @@ def is_layer_gptq_quantized(
def get_linear_quant_method(
config: QuantizationConfig,
config: Union[GPTQConfig, GPTQMarlinConfig],
layer: torch.nn.Module,
prefix: str,
linear_method_cls: type,
......
......@@ -17,8 +17,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.ovis import AIMv2Config
......
......@@ -9,13 +9,14 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
from vllm.config import QuantizationConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment