Unverified Commit 13698db6 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Improve configs - `ModelConfig` (#17130)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 2c4f59af
......@@ -738,7 +738,7 @@ class VllmRunner:
- `block_size`: Set to `16` instead of `None` to reduce memory usage.
- `enable_chunked_prefill`: Set to `False` instead of `None` for
test reproducibility.
- `enforce_eager`: Set to `False` instead of `None` to test CUDA graph.
- `enforce_eager`: Set to `False` to test CUDA graph.
"""
def __init__(
......
......@@ -8,7 +8,7 @@ from typing import Literal, Optional
import pytest
from vllm.config import PoolerConfig, config
from vllm.config import config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs,
......@@ -222,17 +222,6 @@ def test_prefix_cache_default():
assert not engine_args.enable_prefix_caching
def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
'--override-pooler-config',
'{"pooling_type": "MEAN"}',
])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.override_pooler_config == PoolerConfig(
pooling_type="MEAN", )
@pytest.mark.parametrize(
("arg"),
[
......
......@@ -14,7 +14,7 @@ import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
get_quantization_config, register_quantization_config)
QuantizationMethods, get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
......@@ -54,7 +54,7 @@ class CustomQuantConfig(QuantizationConfig):
"""Initialize the quantization config."""
self.num_bits = num_bits
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
"""Name of the quantization method."""
return "custom_quant"
......
......@@ -185,7 +185,7 @@ def test_get_pooling_config():
revision=None,
)
pooling_config = model_config._init_pooler_config(None)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert pooling_config.normalize
......@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
dtype="float16",
revision=None)
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
model_config.override_pooler_config = override_pooler_config
pooling_config = model_config._init_pooler_config(override_config)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert asdict(pooling_config) == asdict(override_config)
assert asdict(pooling_config) == asdict(override_pooler_config)
@pytest.mark.skipif(current_platform.is_rocm(),
......
This diff is collapsed.
This diff is collapsed.
......@@ -13,7 +13,7 @@ from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
......@@ -32,6 +32,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
......@@ -163,20 +164,20 @@ class LLM:
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
tokenizer_mode: TokenizerMode = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
dtype: ModelDType = "auto",
quantization: Optional[QuantizationMethods] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: Optional[int] = None,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
enforce_eager: bool = False,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
......@@ -189,12 +190,7 @@ class LLM:
compilation_config: Optional[Union[int, dict[str, Any]]] = None,
**kwargs,
) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
"""LLM constructor."""
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
......
......@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
......@@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig):
f"out_group_size={self.out_group_size})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "aqlm"
@classmethod
......
......@@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
......@@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig):
f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
......
......@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
......@@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig):
f"modules_to_not_convert={self.modules_to_not_convert})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "awq_marlin"
@classmethod
......@@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig):
modules_to_not_convert, config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "awq_marlin")
......
......@@ -2,11 +2,16 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import torch
from torch import nn
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
else:
QuantizationMethods = str
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
......@@ -66,7 +71,7 @@ class QuantizationConfig(ABC):
self.packed_modules_mapping: Dict[str, List[str]] = dict()
@abstractmethod
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
"""Name of the quantization method."""
raise NotImplementedError
......@@ -99,8 +104,8 @@ class QuantizationConfig(ABC):
raise NotImplementedError
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --
......
......@@ -5,6 +5,7 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
......@@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "bitblas"
@classmethod
......@@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig):
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
......
......@@ -7,6 +7,7 @@ import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import direct_register_custom_op
......@@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig):
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@classmethod
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "bitsandbytes"
@classmethod
......
......@@ -16,6 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
......@@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_min_capability(cls) -> int:
return 70
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "compressed-tensors"
def get_quant_method(
......
......@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
......@@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
f"group_size={self.group_size}")
@classmethod
def get_name(cls) -> str:
return "DeepSpeedFP"
def get_name(cls) -> QuantizationMethods:
return "deepspeedfp"
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
......
......@@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
......@@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig):
super().__init__()
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "experts_int8"
@classmethod
......
......@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
......@@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig):
self.fp8_linear = Fp8LinearOp()
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "fbgemm_fp8"
@classmethod
......
......@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
......@@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig):
self.weight_block_size = weight_block_size
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "fp8"
@classmethod
......
......@@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
def __repr__(self) -> str:
return ("GGUFConfig()")
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
......
......@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
......@@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig):
f"dynamic={self.dynamic}")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq"
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment