Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
......@@ -169,6 +169,7 @@ class AQLMConfig(QuantizationConfig):
num_codebooks: int,
out_group_size: int,
) -> None:
super().__init__()
self.in_group_size = in_group_size
self.nbits_per_codebook = nbits_per_codebook
self.num_codebooks = num_codebooks
......
......@@ -96,6 +96,7 @@ class AWQConfig(QuantizationConfig):
zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
......
......@@ -13,15 +13,18 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
check_marlin_supports_layer, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
......@@ -40,18 +43,18 @@ class AWQMarlinConfig(QuantizationConfig):
8: scalar_types.uint8,
}
def __init__(self,
weight_bits: int,
group_size: int,
zero_point: bool,
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]] = None) -> None:
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.weight_bits = weight_bits
self.modules_to_not_convert = modules_to_not_convert or []
self.full_config = full_config
if self.weight_bits not in self.TYPE_MAP:
raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
......@@ -96,7 +99,7 @@ class AWQMarlinConfig(QuantizationConfig):
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, lm_head_quantized,
modules_to_not_convert)
modules_to_not_convert, config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
......@@ -124,9 +127,21 @@ class AWQMarlinConfig(QuantizationConfig):
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMarlin. "
"Falling back to unoptimized AWQ kernels.")
return AWQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEMethod(self)
if layer.num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return AWQMoEMethod(self)
return None
@classmethod
......
......@@ -2,7 +2,7 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Type
from typing import Any, Dict, List, Optional, Type
import torch
from torch import nn
......@@ -59,7 +59,11 @@ def method_has_implemented_embedding(
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()
@abstractmethod
def get_name(self) -> str:
......
......@@ -30,7 +30,7 @@ class BitsAndBytesConfig(QuantizationConfig):
llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_threshold: float = 6.0,
) -> None:
super().__init__()
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
......@@ -133,8 +133,16 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
components = prefix.split('.')
# Check if any of the skip modules exactly matches any component
return any(module_name in components
for module_name in llm_int8_skip_modules)
substr_check = any(module_name in components
for module_name in llm_int8_skip_modules)
# Allow certain layers to not be quantized
set_components = set(".".join(components[:i + 1])
for i in range(len(components)))
set_llm_int8_skip_modules = set(llm_int8_skip_modules)
prefix_check = len(set_llm_int8_skip_modules & set_components) != 0
return substr_check or prefix_check
class BitsAndBytesLinearMethod(LinearMethodBase):
......
......@@ -52,7 +52,7 @@ class CompressedTensorsConfig(QuantizationConfig):
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
):
super().__init__()
self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]
......@@ -409,13 +409,6 @@ class CompressedTensorsConfig(QuantizationConfig):
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme):
# FIXME(tlrmchlsmth): layers using W16A16 CUTLASS 2:4 sparse kernels
# currently produce bad output in some cases
if weight_quant is None:
logger.warning_once(
"CompressedTensors24 scheme is disabled for the w16a16 "
"case. Falling back to UnquantizedLinearMethod")
return None
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
model_compression_config = (None if sparsity_scheme is None
......
......@@ -64,7 +64,6 @@ class CompressedTensors24(CompressedTensorsScheme):
"Sparse CUTLASS not supported. vLLM must be built with "
"CUDA 12.2 or later to use this feature")
self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes
layer.input_size = input_size
layer.input_size_per_partition = input_size_per_partition
......@@ -205,6 +204,11 @@ class CompressedTensors24(CompressedTensorsScheme):
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False)
# Set all negative zero values to 0 prior to compression
if (layer.weight.dtype.is_floating_point
and layer.weight.dtype.itemsize >= 2):
layer.weight.data[layer.weight.data == -0.0] = 0.0
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
......@@ -254,9 +258,10 @@ class CompressedTensors24(CompressedTensorsScheme):
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=self.output_dtype,
out_dtype=x.dtype,
bias=bias,
)
assert out.is_contiguous()
return out
......
......@@ -9,8 +9,8 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale)
apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
......@@ -93,6 +93,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
......
......@@ -25,6 +25,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
weight_bits: int = 8,
group_size: int = 512,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.valid_types = [torch.bfloat16, torch.float16]
......
......@@ -17,7 +17,7 @@ class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization."""
def __init__(self) -> None:
pass
super().__init__()
@classmethod
def get_name(cls) -> str:
......
......@@ -17,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
apply_fp8_linear, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform
......@@ -29,6 +30,7 @@ class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""
def __init__(self, ignore_list: List[str], input_scale_ub: float):
super().__init__()
self.ignore_list = ignore_list if ignore_list else []
self.input_scale_ub = input_scale_ub
......@@ -83,6 +85,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
......
......@@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale)
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
......@@ -47,6 +47,7 @@ class Fp8Config(QuantizationConfig):
ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "
......@@ -161,6 +162,8 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
......
......@@ -20,7 +20,7 @@ class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
def __init__(self, ) -> None:
pass
super().__init__()
def __repr__(self) -> str:
return ("GGUFConfig()")
......
......@@ -3,16 +3,17 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
......@@ -32,7 +33,34 @@ class GPTQConfig(QuantizationConfig):
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
super().__init__()
self.dynamic = dynamic
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
......@@ -47,7 +75,8 @@ class GPTQConfig(QuantizationConfig):
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")
@classmethod
def get_name(cls) -> str:
......@@ -68,19 +97,20 @@ class GPTQConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized)
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
dynamic)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQLinearMethod(self)
return None
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
class ExllamaState(Enum):
......
......@@ -9,17 +9,19 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
......@@ -40,23 +42,49 @@ class GPTQMarlinConfig(QuantizationConfig):
(8, True): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
self.dynamic = dynamic
self.weight_bits = weight_bits
self.is_sym = is_sym
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.full_config = full_config
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
......@@ -68,7 +96,8 @@ class GPTQMarlinConfig(QuantizationConfig):
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})")
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")
@classmethod
def get_name(cls) -> str:
......@@ -88,6 +117,9 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
......@@ -95,7 +127,7 @@ class GPTQMarlinConfig(QuantizationConfig):
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized)
lm_head_quantized, dynamic, config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
......@@ -118,19 +150,20 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference")
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return GPTQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, FusedMoE):
if layer.num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return GPTQMarlinMoEMethod(self)
return get_linear_quant_method(self, layer, prefix,
GPTQMarlinLinearMethod)
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
......@@ -143,7 +176,7 @@ class GPTQMarlinConfig(QuantizationConfig):
if quant_method != "gptq":
return False
# If we cannot find the info needed in the config, cannot convert.
# Marlin conversion is only valid if required properties are found
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
......@@ -323,13 +356,18 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Currently assuming is_k_full is always True
# (input size per partition is the same as full input size)
# Supports only sym for now (no zp)
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full")
self.is_k_full = (not self.quant_config.desc_act) or (
intermediate_size_per_partition == intermediate_size_full)
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
scales_size2 = (intermediate_size_per_partition //
self.quant_config.group_size)
w2_scales_size = (intermediate_size_full
if self.quant_config.desc_act else
intermediate_size_per_partition)
scales_size2 = (w2_scales_size // self.quant_config.group_size)
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
......@@ -385,6 +423,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# dont shard the w2 scales when running act order
set_weight_attrs(w2_scales,
{"load_full_w2": self.quant_config.desc_act})
# up_proj scales
w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
......@@ -406,6 +447,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
# dont shard the w2 scales when running act order
set_weight_attrs(w2_qzeros,
{"load_full_w2": self.quant_config.desc_act})
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
......@@ -575,4 +619,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype)
is_k_full=self.is_k_full).to(orig_dtype)
......@@ -38,6 +38,7 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits: int,
group_size: int,
) -> None:
super().__init__()
quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
......
......@@ -33,6 +33,7 @@ class HQQMarlinConfig(QuantizationConfig):
group_size: int,
skip_modules: Optional[List[str]] = None,
) -> None:
super().__init__()
assert group_size == 64, ("The only supported HQQ group size is "
"currently 64.")
assert weight_bits == 4, ("The only supported HQQ quantization "
......
......@@ -35,6 +35,7 @@ class IPEXConfig(QuantizationConfig):
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
super().__init__()
self.method = method
self.weight_bits = weight_bits
self.group_size = group_size
......
......@@ -28,6 +28,7 @@ class ModelOptFp8Config(QuantizationConfig):
self,
is_checkpoint_fp8_serialized: bool = False,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
......
......@@ -9,13 +9,10 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
......@@ -27,6 +24,7 @@ class MoeWNA16Config(QuantizationConfig):
group_size: int, has_zp: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.has_zp = has_zp
......@@ -35,6 +33,12 @@ class MoeWNA16Config(QuantizationConfig):
self.linear_quant_method = linear_quant_method
self.full_config = full_config
self.use_marlin = False
# Avoid circular import
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
if self.linear_quant_method == "gptq":
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
full_config)
......@@ -87,8 +91,8 @@ class MoeWNA16Config(QuantizationConfig):
modules_to_not_convert = []
elif linear_quant_method == "awq":
has_zp = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys(
config, ["modules_to_not_convert"])
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
else:
raise ValueError("moe_wna16 only support gptq and awq.")
......@@ -113,6 +117,8 @@ class MoeWNA16Config(QuantizationConfig):
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
# Avoid circular import
from vllm.model_executor.layers.quantization.awq import AWQConfig
awq_min_capability = AWQConfig.get_min_capability()
gptq_compatible = quant_method == "gptq" and \
......@@ -127,6 +133,13 @@ class MoeWNA16Config(QuantizationConfig):
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
elif isinstance(layer, LinearBase):
# Avoid circular import
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
if self.linear_quant_method == "gptq":
if self.use_marlin:
return GPTQMarlinConfig.from_config(
......@@ -135,7 +148,8 @@ class MoeWNA16Config(QuantizationConfig):
return GPTQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
elif self.linear_quant_method == "awq":
if self.use_marlin:
if self.use_marlin and check_marlin_supports_layer(
layer, self.group_size):
return AWQMarlinConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment