Unverified commit 13698db6, authored by Harry Mellor and committed by GitHub
Browse files

Improve configs - `ModelConfig` (#17130)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 2c4f59af
...@@ -738,7 +738,7 @@ class VllmRunner: ...@@ -738,7 +738,7 @@ class VllmRunner:
- `block_size`: Set to `16` instead of `None` to reduce memory usage. - `block_size`: Set to `16` instead of `None` to reduce memory usage.
- `enable_chunked_prefill`: Set to `False` instead of `None` for - `enable_chunked_prefill`: Set to `False` instead of `None` for
test reproducibility. test reproducibility.
- `enforce_eager`: Set to `False` instead of `None` to test CUDA graph. - `enforce_eager`: Set to `False` to test CUDA graph.
""" """
def __init__( def __init__(
......
...@@ -8,7 +8,7 @@ from typing import Literal, Optional ...@@ -8,7 +8,7 @@ from typing import Literal, Optional
import pytest import pytest
from vllm.config import PoolerConfig, config from vllm.config import config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type, get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs, literal_to_kwargs, nullable_kvs,
...@@ -222,17 +222,6 @@ def test_prefix_cache_default(): ...@@ -222,17 +222,6 @@ def test_prefix_cache_default():
assert not engine_args.enable_prefix_caching assert not engine_args.enable_prefix_caching
def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
'--override-pooler-config',
'{"pooling_type": "MEAN"}',
])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.override_pooler_config == PoolerConfig(
pooling_type="MEAN", )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("arg"), ("arg"),
[ [
......
...@@ -14,7 +14,7 @@ import torch.nn.functional as F ...@@ -14,7 +14,7 @@ import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase # noqa: E501 from vllm.model_executor.layers.linear import LinearBase # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import ( from vllm.model_executor.layers.quantization import (
get_quantization_config, register_quantization_config) QuantizationMethods, get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig) QuantizationConfig)
...@@ -54,7 +54,7 @@ class CustomQuantConfig(QuantizationConfig): ...@@ -54,7 +54,7 @@ class CustomQuantConfig(QuantizationConfig):
"""Initialize the quantization config.""" """Initialize the quantization config."""
self.num_bits = num_bits self.num_bits = num_bits
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
"""Name of the quantization method.""" """Name of the quantization method."""
return "custom_quant" return "custom_quant"
......
...@@ -185,7 +185,7 @@ def test_get_pooling_config(): ...@@ -185,7 +185,7 @@ def test_get_pooling_config():
revision=None, revision=None,
) )
pooling_config = model_config._init_pooler_config(None) pooling_config = model_config._init_pooler_config()
assert pooling_config is not None assert pooling_config is not None
assert pooling_config.normalize assert pooling_config.normalize
...@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args(): ...@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
dtype="float16", dtype="float16",
revision=None) revision=None)
override_config = PoolerConfig(pooling_type='CLS', normalize=True) override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
model_config.override_pooler_config = override_pooler_config
pooling_config = model_config._init_pooler_config(override_config) pooling_config = model_config._init_pooler_config()
assert pooling_config is not None assert pooling_config is not None
assert asdict(pooling_config) == asdict(override_config) assert asdict(pooling_config) == asdict(override_pooler_config)
@pytest.mark.skipif(current_platform.is_rocm(), @pytest.mark.skipif(current_platform.is_rocm(),
......
This diff is collapsed.
This diff is collapsed.
...@@ -13,7 +13,7 @@ from typing_extensions import TypeVar, deprecated ...@@ -13,7 +13,7 @@ from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score) BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig from vllm.config import CompilationConfig, ModelDType, TokenizerMode
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption) TaskOption)
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
...@@ -32,6 +32,7 @@ from vllm.logger import init_logger ...@@ -32,6 +32,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions) GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput, PoolingRequestOutput, RequestOutput,
ScoringRequestOutput) ScoringRequestOutput)
...@@ -163,20 +164,20 @@ class LLM: ...@@ -163,20 +164,20 @@ class LLM:
self, self,
model: str, model: str,
tokenizer: Optional[str] = None, tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto", tokenizer_mode: TokenizerMode = "auto",
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
trust_remote_code: bool = False, trust_remote_code: bool = False,
allowed_local_media_path: str = "", allowed_local_media_path: str = "",
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: str = "auto", dtype: ModelDType = "auto",
quantization: Optional[str] = None, quantization: Optional[QuantizationMethods] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: float = 4, swap_space: float = 4,
cpu_offload_gb: float = 0, cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None, enforce_eager: bool = False,
max_seq_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
...@@ -189,12 +190,7 @@ class LLM: ...@@ -189,12 +190,7 @@ class LLM:
compilation_config: Optional[Union[int, dict[str, Any]]] = None, compilation_config: Optional[Union[int, dict[str, Any]]] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
''' """LLM constructor."""
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
......
...@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter ...@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig): ...@@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig):
f"out_group_size={self.out_group_size})") f"out_group_size={self.out_group_size})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "aqlm" return "aqlm"
@classmethod @classmethod
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
...@@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig): ...@@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig):
f"zero_point={self.zero_point}, " f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})") f"modules_to_not_convert={self.modules_to_not_convert})")
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "awq" return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQConfig, from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq) is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig):
f"modules_to_not_convert={self.modules_to_not_convert})") f"modules_to_not_convert={self.modules_to_not_convert})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "awq_marlin" return "awq_marlin"
@classmethod @classmethod
...@@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig):
modules_to_not_convert, config) modules_to_not_convert, config)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin" is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "awq_marlin") or user_quant == "awq_marlin")
......
...@@ -2,11 +2,16 @@ ...@@ -2,11 +2,16 @@
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import torch import torch
from torch import nn from torch import nn
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
else:
QuantizationMethods = str
class QuantizeMethodBase(ABC): class QuantizeMethodBase(ABC):
"""Base class for different quantized methods.""" """Base class for different quantized methods."""
...@@ -66,7 +71,7 @@ class QuantizationConfig(ABC): ...@@ -66,7 +71,7 @@ class QuantizationConfig(ABC):
self.packed_modules_mapping: Dict[str, List[str]] = dict() self.packed_modules_mapping: Dict[str, List[str]] = dict()
@abstractmethod @abstractmethod
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
"""Name of the quantization method.""" """Name of the quantization method."""
raise NotImplementedError raise NotImplementedError
...@@ -99,8 +104,8 @@ class QuantizationConfig(ABC): ...@@ -99,8 +104,8 @@ class QuantizationConfig(ABC):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
""" """
Detects if this quantization method can support a given checkpoint Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method -- format by overriding the user specified quantization method --
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
...@@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig): ...@@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})") f"quant_method={self.quant_method})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "bitblas" return "bitblas"
@classmethod @classmethod
...@@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig): ...@@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig):
lm_head_quantized) lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool # compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig):
f"llm_int8_skip_modules={self.llm_int8_skip_modules})") f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@classmethod @classmethod
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "bitsandbytes" return "bitsandbytes"
@classmethod @classmethod
......
...@@ -16,6 +16,7 @@ from vllm.logger import init_logger ...@@ -16,6 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
...@@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 70 return 70
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "compressed-tensors" return "compressed-tensors"
def get_quant_method( def get_quant_method(
......
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig): ...@@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
f"group_size={self.group_size}") f"group_size={self.group_size}")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "DeepSpeedFP" return "deepspeedfp"
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
......
...@@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group ...@@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig): ...@@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig):
super().__init__() super().__init__()
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "experts_int8" return "experts_int8"
@classmethod @classmethod
......
...@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter ...@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
...@@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig): ...@@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig):
self.fp8_linear = Fp8LinearOp() self.fp8_linear = Fp8LinearOp()
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "fbgemm_fp8" return "fbgemm_fp8"
@classmethod @classmethod
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
...@@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig): ...@@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig):
self.weight_block_size = weight_block_size self.weight_block_size = weight_block_size
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "fp8" return "fp8"
@classmethod @classmethod
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig): ...@@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
def __repr__(self) -> str: def __repr__(self) -> str:
return ("GGUFConfig()") return ("GGUFConfig()")
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "gguf" return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:
......
...@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter ...@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import ( from vllm.model_executor.layers.quantization.utils.gptq_utils import (
...@@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig): ...@@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig):
f"dynamic={self.dynamic}") f"dynamic={self.dynamic}")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "gptq" return "gptq"
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment