Unverified commit 13698db6, authored by Harry Mellor, committed by GitHub
Browse files

Improve configs - `ModelConfig` (#17130)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 2c4f59af
...@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter ...@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
...@@ -123,7 +124,7 @@ class GPTQBitBLASConfig(QuantizationConfig): ...@@ -123,7 +124,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})") f"quant_method={self.quant_method})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "gptq_bitblas" return "gptq_bitblas"
@classmethod @classmethod
...@@ -151,8 +152,8 @@ class GPTQBitBLASConfig(QuantizationConfig): ...@@ -151,8 +152,8 @@ class GPTQBitBLASConfig(QuantizationConfig):
lm_head_quantized) lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "bitblas" is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
......
...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
...@@ -100,7 +101,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -100,7 +101,7 @@ class GPTQMarlinConfig(QuantizationConfig):
f"dynamic={self.dynamic}") f"dynamic={self.dynamic}")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "gptq_marlin" return "gptq_marlin"
@classmethod @classmethod
...@@ -130,8 +131,8 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -130,8 +131,8 @@ class GPTQMarlinConfig(QuantizationConfig):
lm_head_quantized, dynamic, config) lm_head_quantized, dynamic, config)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin" is_valid_user_quant = (user_quant is None or user_quant == "marlin"
......
...@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter ...@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
...@@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig):
self.quant_type, self.group_size) self.quant_type, self.group_size)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "gptq_marlin_24" return "gptq_marlin_24"
@classmethod @classmethod
...@@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
return cls(weight_bits, group_size) return cls(weight_bits, group_size)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
is_marlin_24_format = ( is_marlin_24_format = (
hf_quant_cfg.get("checkpoint_format") == "marlin_24") hf_quant_cfg.get("checkpoint_format") == "marlin_24")
......
...@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops ...@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...@@ -50,7 +51,7 @@ class HQQMarlinConfig(QuantizationConfig): ...@@ -50,7 +51,7 @@ class HQQMarlinConfig(QuantizationConfig):
f"group_size={self.group_size})") f"group_size={self.group_size})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "hqq" return "hqq"
@classmethod @classmethod
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq) is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -58,7 +59,7 @@ class IPEXConfig(QuantizationConfig): ...@@ -58,7 +59,7 @@ class IPEXConfig(QuantizationConfig):
f"group_size={self.group_size})") f"group_size={self.group_size})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "ipex" return "ipex"
@classmethod @classmethod
...@@ -97,8 +98,8 @@ class IPEXConfig(QuantizationConfig): ...@@ -97,8 +98,8 @@ class IPEXConfig(QuantizationConfig):
lm_head_quantized) lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if not current_platform.is_cpu() and not current_platform.is_xpu(): if not current_platform.is_cpu() and not current_platform.is_xpu():
return None return None
......
...@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter ...@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...@@ -63,7 +64,7 @@ class MarlinConfig(QuantizationConfig): ...@@ -63,7 +64,7 @@ class MarlinConfig(QuantizationConfig):
f"lm_head_quantized={self.lm_head_quantized})") f"lm_head_quantized={self.lm_head_quantized})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "marlin" return "marlin"
@classmethod @classmethod
...@@ -87,8 +88,8 @@ class MarlinConfig(QuantizationConfig): ...@@ -87,8 +88,8 @@ class MarlinConfig(QuantizationConfig):
return cls(group_size, lm_head_quantized) return cls(group_size, lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool # compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
......
...@@ -11,6 +11,7 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm, ...@@ -11,6 +11,7 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
...@@ -42,7 +43,7 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -42,7 +43,7 @@ class ModelOptFp8Config(QuantizationConfig):
" the format is experimental and could change.") " the format is experimental and could change.")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "modelopt" return "modelopt"
@classmethod @classmethod
...@@ -184,8 +185,8 @@ class ModelOptNvFp4Config(QuantizationConfig): ...@@ -184,8 +185,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
self.exclude_modules = exclude_modules self.exclude_modules = exclude_modules
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "modelopt_nvfp4" return "nvfp4"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> List[torch.dtype]:
......
...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...@@ -64,7 +65,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -64,7 +65,7 @@ class MoeWNA16Config(QuantizationConfig):
self.modules_to_not_convert = modules_to_not_convert self.modules_to_not_convert = modules_to_not_convert
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "moe_wna16" return "moe_wna16"
@classmethod @classmethod
...@@ -100,8 +101,8 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -100,8 +101,8 @@ class MoeWNA16Config(QuantizationConfig):
lm_head_quantized, modules_to_not_convert, config) lm_head_quantized, modules_to_not_convert, config)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16": if can_convert and user_quant == "moe_wna16":
return cls.get_name() return cls.get_name()
......
...@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional ...@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
from torch.nn import Module from torch.nn import Module
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
...@@ -30,7 +31,7 @@ class NeuronQuantConfig(QuantizationConfig): ...@@ -30,7 +31,7 @@ class NeuronQuantConfig(QuantizationConfig):
self.dequant_dtype = dequant_dtype self.dequant_dtype = dequant_dtype
self.quantize_method = quantize_method self.quantize_method = quantize_method
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "neuron_quant" return "neuron_quant"
def get_supported_act_dtypes(self) -> List[str]: def get_supported_act_dtypes(self) -> List[str]:
......
...@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops ...@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase) QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
...@@ -50,7 +51,7 @@ class PTPCFp8Config(Fp8Config): ...@@ -50,7 +51,7 @@ class PTPCFp8Config(Fp8Config):
ignored_layers=ignored_layers) ignored_layers=ignored_layers)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "ptpc_fp8" return "ptpc_fp8"
@classmethod @classmethod
......
...@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter ...@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
...@@ -84,7 +85,7 @@ class QQQConfig(QuantizationConfig): ...@@ -84,7 +85,7 @@ class QQQConfig(QuantizationConfig):
self.weight_bits, self.group_size) self.weight_bits, self.group_size)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "qqq" return "qqq"
@classmethod @classmethod
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
...@@ -47,7 +48,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -47,7 +48,7 @@ class QuarkConfig(QuantizationConfig):
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 70 return 70
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "quark" return "quark"
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
......
...@@ -6,6 +6,7 @@ import torch.nn.functional as F ...@@ -6,6 +6,7 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -20,7 +21,7 @@ class TorchAOConfig(QuantizationConfig): ...@@ -20,7 +21,7 @@ class TorchAOConfig(QuantizationConfig):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})" return f"TorchAOConfig({self.torchao_config})"
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "torchao" return "torchao"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:
......
...@@ -7,6 +7,7 @@ from torch.nn import Module ...@@ -7,6 +7,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter from vllm.model_executor.parameter import ModelWeightParameter
...@@ -27,7 +28,7 @@ class Int8TpuConfig(QuantizationConfig): ...@@ -27,7 +28,7 @@ class Int8TpuConfig(QuantizationConfig):
f"Unsupported activation scheme {activation_scheme}") f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme self.activation_scheme = activation_scheme
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "tpu_int8" return "tpu_int8"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:
......
...@@ -1496,7 +1496,7 @@ def get_rope( ...@@ -1496,7 +1496,7 @@ def get_rope(
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
if rope_scaling is None: if not rope_scaling:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype) is_neox_style, dtype)
else: else:
......
...@@ -180,7 +180,6 @@ def _get_neuron_config_after_override(default_neuron_config, ...@@ -180,7 +180,6 @@ def _get_neuron_config_after_override(default_neuron_config,
NeuronConfig, QuantizationConfig, NeuronConfig, QuantizationConfig,
SparseAttnConfig) SparseAttnConfig)
overridden_neuron_config = overridden_neuron_config or {}
sparse_attn = overridden_neuron_config.pop("sparse_attn", {}) sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
if sparse_attn: if sparse_attn:
overridden_neuron_config["sparse_attn"] = SparseAttnConfig( overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment