Unverified Commit a8d604ca authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Misc] Disambiguate quantized types via a new ScalarType (#6396)

parent b482b9a5
...@@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"] __all__ = ["CompressedTensorsW4A16Sparse24"]
W4A16SPARSE24_SUPPORTED_BITS = [4] W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
}
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
...@@ -22,9 +26,15 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -22,9 +26,15 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
group_size: Optional[int] = None): group_size: Optional[int] = None):
self.strategy = strategy self.strategy = strategy
self.group_size = group_size self.group_size = group_size
self.num_bits = num_bits
self.tile_size = 16 self.tile_size = 16
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
if self.strategy == "group" and self.group_size is None: if self.strategy == "group" and self.group_size is None:
raise ValueError( raise ValueError(
"group_size must be given when using strategy group") "group_size must be given when using strategy group")
...@@ -43,7 +53,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -43,7 +53,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
pack_factor = 32 // self.num_bits pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter( qweight = Parameter(
...@@ -138,7 +148,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -138,7 +148,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
size_n = scales.shape[1] size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace, self.num_bits, size_m, workspace, self.quant_type, size_m,
size_n, size_k) size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
......
...@@ -8,12 +8,17 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -8,12 +8,17 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported, marlin_permute_scales, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape) verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"] __all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_BITS = [4, 8] WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme): class CompressedTensorsWNA16(CompressedTensorsScheme):
...@@ -22,8 +27,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -22,8 +27,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
strategy: str, strategy: str,
num_bits: int, num_bits: int,
group_size: Optional[int] = None): group_size: Optional[int] = None):
self.num_bits = num_bits
self.pack_factor = 32 // self.num_bits self.pack_factor = 32 // num_bits
self.strategy = strategy self.strategy = strategy
self.group_size: int self.group_size: int
...@@ -37,10 +42,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -37,10 +42,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
else: else:
self.group_size = group_size self.group_size = group_size
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
# Verify supported on platform. # Verify supported on platform.
verify_gptq_marlin_supported(num_bits=self.num_bits, verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size, group_size=self.group_size)
is_sym=True)
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -150,7 +161,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -150,7 +161,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
perm=layer.g_idx_sort_indices, perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.num_bits) num_bits=self.quant_type.size_bits)
replace_tensor(layer, "weight_packed", marlin_qweight) replace_tensor(layer, "weight_packed", marlin_qweight)
# Permute scales from compressed-tensors format to marlin format. # Permute scales from compressed-tensors format to marlin format.
...@@ -172,7 +183,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -172,7 +183,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
g_idx=layer.g_idx, g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices, g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
num_bits=self.num_bits, wtype=self.quant_type,
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
is_k_full=True, is_k_full=True,
......
...@@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full, apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_gptq_marlin_supported, verify_marlin_supports_shape) verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -22,6 +23,12 @@ logger = init_logger(__name__) ...@@ -22,6 +23,12 @@ logger = init_logger(__name__)
class GPTQMarlinConfig(QuantizationConfig): class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin""" """Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
def __init__(self, weight_bits: int, group_size: int, desc_act: bool, def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool) -> None: is_sym: bool, lm_head_quantized: bool) -> None:
if desc_act and group_size == -1: if desc_act and group_size == -1:
...@@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig):
# (since we have only one group per output channel) # (since we have only one group per output channel)
desc_act = False desc_act = False
self.weight_bits = weight_bits self.pack_factor = 32 // weight_bits # packed into int32
self.pack_factor = 32 // self.weight_bits # packed into int32
self.group_size = group_size self.group_size = group_size
self.desc_act = desc_act self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized self.lm_head_quantized = lm_head_quantized
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={weight_bits}, sym={is_sym}")
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
# Verify supported on platform. # Verify supported on platform.
verify_gptq_marlin_supported(num_bits=self.weight_bits, verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size, group_size=self.group_size)
is_sym=self.is_sym)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, " f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, " f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})") f"lm_head_quantized={self.lm_head_quantized})")
...@@ -122,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -122,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig):
or desc_act is None): or desc_act is None):
return False return False
return check_gptq_marlin_supported( if (num_bits, sym) not in cls.TYPE_MAP:
num_bits=num_bits, return False
group_size=group_size,
is_sym=sym, return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
min_capability=cls.get_min_capability()) group_size=group_size,
min_capability=cls.get_min_capability())
class GPTQMarlinLinearMethod(LinearMethodBase): class GPTQMarlinLinearMethod(LinearMethodBase):
...@@ -293,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -293,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm=layer.g_idx_sort_indices, perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits) num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight) replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format. # Permute scales from autogptq format to marlin format.
...@@ -319,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -319,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
g_idx=layer.g_idx, g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices, g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
num_bits=self.quant_config.weight_bits, wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full, is_k_full=layer.is_k_full,
......
...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase ...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128 ...@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128 GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64 GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
scalar_types.uint4b8, scalar_types.uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
class GPTQMarlin24Config(QuantizationConfig): class GPTQMarlin24Config(QuantizationConfig):
...@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits: int, weight_bits: int,
group_size: int, group_size: int,
) -> None: ) -> None:
self.weight_bits = weight_bits quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}.get(weight_bits)
self.group_size = group_size self.group_size = group_size
# Verify # Verify
if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: if quant_type is None or \
quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
raise ValueError( raise ValueError(
f"Marlin_24 does not support weight_bits = {self.weight_bits}. " f"Marlin_24 does not support quant_type = {quant_type}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} " f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
"are supported.") "are supported.")
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
raise ValueError( raise ValueError(
...@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
"are supported.") "are supported.")
self.quant_type = quant_type
# 4 Bits packed into 32 bit datatype. # 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.weight_bits self.pack_factor = 32 // self.quant_type.size_bits
# Tile size used by marlin kernels. # Tile size used by marlin kernels.
self.tile_size = 16 self.tile_size = 16
...@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
self.perm_len = 1024 self.perm_len = 1024
def __repr__(self) -> str: def __repr__(self) -> str:
return "Marlin24Config(weight_bits={}, group_size={})".format( return "Marlin24Config(quant_type={}, group_size={})".format(
self.weight_bits, self.group_size) self.quant_type, self.group_size)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
...@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase): ...@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace, workspace,
self.quant_config.weight_bits, self.quant_config.quant_type,
size_m, size_n, size_k) size_m, size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols from .quant_utils import pack_cols, unpack_cols
...@@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_N = 64 ...@@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16 GPTQ_MARLIN_MAX_PARALLEL = 16
MARLIN_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# In case there is a performance issue with Marlin, the variable below can be # In case there is a performance issue with Marlin, the variable below can be
...@@ -22,76 +22,70 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] ...@@ -22,76 +22,70 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
USE_FP32_REDUCE_DEFAULT = True USE_FP32_REDUCE_DEFAULT = True
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, # For binary size and compile time, we don't support the same types for with and
min_capability: Optional[int], # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
has_zp: bool) -> Tuple[bool, Optional[str]]: # TODO: we may want to move this into the C++ so its closer to the actual impl
if min_capability is not None: def query_marlin_supported_quant_types(has_zp: bool,
min_capability: Optional[int] = None):
if min_capability is None:
major, minor = current_platform.get_device_capability() major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor min_capability = major * 10 + minor
if device_capability < min_capability:
return (False, "Marlin does not support device_capability = {}"
", the min_capability required is {}".format(
device_capability, min_capability))
if num_bits not in MARLIN_SUPPORTED_NUM_BITS:
return (False, "Marlin does not support weight_bits = {}. "
"Only weight_bits = {} are supported.".format(
num_bits, MARLIN_SUPPORTED_NUM_BITS))
if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return (False, "Marlin does not support group_size = {}. Only "
"group_sizes = {} are supported.".format(
group_size, MARLIN_SUPPORTED_GROUP_SIZES))
if not has_zp and not is_sym:
return (False,
"Marlin without zero_points must have symmetric quantization")
return True, None if min_capability < 80:
return []
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: int) -> bool:
cond, _ = _check_marlin_supported(num_bits,
group_size,
is_sym,
min_capability,
has_zp=False)
return cond
def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool, if min_capability is None:
min_capability: int) -> bool: major, minor = current_platform.get_device_capability()
cond, _ = _check_marlin_supported(num_bits, min_capability = major * 10 + minor
group_size,
False,
min_capability,
has_zp=has_zp)
return cond
supported_types = query_marlin_supported_quant_types(
has_zp, min_capability)
def verify_gptq_marlin_supported(num_bits: int, group_size: int, if quant_type not in supported_types:
is_sym: bool) -> None: return (False, f"Marlin does not support weight_bits = {quant_type}. "
cond, err_msg = _check_marlin_supported(num_bits, f"Only types = {supported_types} "
group_size, f"are supported (for group_size = {group_size}, "
is_sym, f"min_capability = {min_capability}, zp = {has_zp}).")
min_capability=None, if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
has_zp=False) return (False, f"Marlin does not support group_size = {group_size}. "
if not cond: f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
assert err_msg is not None "are supported.")
raise ValueError("GPTQ" + err_msg)
return True, None
def check_marlin_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
min_capability: Optional[int] = None) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
min_capability)
return cond
def verify_awq_marlin_supported(num_bits: int, group_size: int, def verify_marlin_supported(quant_type: ScalarType,
has_zp: bool) -> None: group_size: int,
cond, err_msg = _check_marlin_supported(num_bits, has_zp: bool = False) -> None:
group_size, cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
False,
min_capability=None,
has_zp=has_zp)
if not cond: if not cond:
assert err_msg is not None assert err_msg is not None
raise ValueError("AWQ" + err_msg) raise ValueError(err_msg)
def verify_marlin_supports_shape(output_size_per_partition: int, def verify_marlin_supports_shape(output_size_per_partition: int,
...@@ -245,7 +239,7 @@ def apply_gptq_marlin_linear( ...@@ -245,7 +239,7 @@ def apply_gptq_marlin_linear(
g_idx: torch.Tensor, g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor, g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor, workspace: torch.Tensor,
num_bits: int, wtype: ScalarType,
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
is_k_full: bool, is_k_full: bool,
...@@ -261,7 +255,7 @@ def apply_gptq_marlin_linear( ...@@ -261,7 +255,7 @@ def apply_gptq_marlin_linear(
g_idx, g_idx,
g_idx_sort_indices, g_idx_sort_indices,
workspace, workspace,
num_bits, wtype,
size_m=reshaped_x.shape[0], size_m=reshaped_x.shape[0],
size_n=output_size_per_partition, size_n=output_size_per_partition,
size_k=input_size_per_partition, size_k=input_size_per_partition,
...@@ -283,7 +277,7 @@ def apply_awq_marlin_linear( ...@@ -283,7 +277,7 @@ def apply_awq_marlin_linear(
g_idx: torch.Tensor, g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor, g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor, workspace: torch.Tensor,
num_bits: int, quant_type: ScalarType,
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
...@@ -298,7 +292,7 @@ def apply_awq_marlin_linear( ...@@ -298,7 +292,7 @@ def apply_awq_marlin_linear(
g_idx, g_idx,
g_idx_sort_indices, g_idx_sort_indices,
workspace, workspace,
num_bits, quant_type,
size_m=reshaped_x.shape[0], size_m=reshaped_x.shape[0],
size_n=output_size_per_partition, size_n=output_size_per_partition,
size_k=input_size_per_partition, size_k=input_size_per_partition,
......
...@@ -5,10 +5,12 @@ from typing import List ...@@ -5,10 +5,12 @@ from typing import List
import numpy as np import numpy as np
import torch import torch
from vllm.scalar_type import ScalarType
from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales,
marlin_zero_points) marlin_zero_points)
from .quant_utils import (get_pack_factor, quantize_weights, from .quant_utils import (get_pack_factor, gptq_quantize_weights,
quantize_weights_with_zp, sort_weights) quantize_weights, sort_weights)
class MarlinWorkspace: class MarlinWorkspace:
...@@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int): ...@@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
return perm return perm
def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
act_order: bool): act_order: bool):
size_k, size_n = w.shape size_k, size_n = w.shape
num_bits = quant_type.size_bits
# Normalize group_size # Normalize group_size
if group_size == -1: if group_size == -1:
...@@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
assert group_size <= size_k assert group_size <= size_k
# Quantize (and apply act_order if provided) # Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
act_order) w, quant_type, group_size, act_order)
# For act_order, sort the "weights" and "g_idx" so that group ids are # For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing # increasing
...@@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
return res_list return res_list
def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType,
group_size: int):
size_k, size_n = w.shape size_k, size_n = w.shape
# Normalize group_size # Normalize group_size
...@@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): ...@@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
num_groups = size_k // group_size num_groups = size_k // group_size
# Quantize with zp # Quantize with zp
w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size) w_ref, q_w, s, zp = quantize_weights(w,
quant_type,
group_size,
zero_points=True)
# Reformat to marlin # Reformat to marlin
weight_perm = get_weight_perm(num_bits) weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits) marlin_zp = marlin_zero_points(zp, num_groups, size_n,
quant_type.size_bits)
# Create result # Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
......
...@@ -6,8 +6,10 @@ from typing import List ...@@ -6,8 +6,10 @@ from typing import List
import numpy import numpy
import torch import torch
from vllm.scalar_type import ScalarType
from .marlin_utils_test import marlin_weights from .marlin_utils_test import marlin_weights
from .quant_utils import quantize_weights from .quant_utils import gptq_quantize_weights
# This is PyTorch implementation of main part of reorder_meta() # This is PyTorch implementation of main part of reorder_meta()
...@@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False): ...@@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
assert q_24.shape == (size_k, size_n) assert q_24.shape == (size_k, size_n)
# Remove zp to normalize over 0 # Remove bias to normalize over 0
max_q_val = (1 << num_bits) - 1 q_24_no_zp = q_24 - wtype.bias
zp = (max_q_val + 1) // 2
q_24_no_zp = q_24 - zp
# Compress # Compress
q_24_no_zp = q_24_no_zp.t().contiguous() q_24_no_zp = q_24_no_zp.t().contiguous()
...@@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): ...@@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
q_24_no_zp) q_24_no_zp)
q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
# Restore zp # Restore bias
q_24_comp = q_24_no_zp_comp + zp q_24_comp = q_24_no_zp_comp + wtype.bias
# Resize meta to its actual shape (without moving any data) # Resize meta to its actual shape (without moving any data)
meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
...@@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, ...@@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
def marlin_24_quantize( def marlin_24_quantize(
w: torch.Tensor, w: torch.Tensor,
num_bits: int, quant_type: ScalarType,
group_size: int, group_size: int,
): ):
size_k, size_n = w.shape size_k, size_n = w.shape
...@@ -441,20 +441,18 @@ def marlin_24_quantize( ...@@ -441,20 +441,18 @@ def marlin_24_quantize(
w_24, mask_24 = inject_24(w, size_k, size_n) w_24, mask_24 = inject_24(w, size_k, size_n)
# Quantize # Quantize
w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
num_bits, w_24, quant_type, group_size, act_order=False)
group_size,
act_order=False)
# Compress quantized weight # Compress quantized weight
q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
num_bits) quant_type)
size_k_comp = size_k // 2 size_k_comp = size_k // 2
# Reformat to marlin # Reformat to marlin
weight_perm = get_weight_perm_24(num_bits) weight_perm = get_weight_perm_24(quant_type.size_bits)
marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
num_bits, weight_perm) quant_type.size_bits, weight_perm)
marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
# Create result # Create result
......
...@@ -4,7 +4,11 @@ from typing import List ...@@ -4,7 +4,11 @@ from typing import List
import numpy import numpy
import torch import torch
SUPPORTED_NUM_BITS = [4, 8] from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the # Note: this is a hack. We should update each model to register the
...@@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: ...@@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
def get_pack_factor(num_bits): def get_pack_factor(num_bits):
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits return 32 // num_bits
...@@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): ...@@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
) )
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, def quantize_weights(w: torch.Tensor,
act_order: bool): quant_type: ScalarType,
group_size: int,
zero_points: bool = False):
assert quant_type.is_integer(), \
"Floating point quantization may work but has not been tested"
orig_device = w.device orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float" assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
if group_size == -1: if group_size == -1:
group_size = size_k group_size = size_k
assert group_size <= size_k assert group_size <= size_k
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2
# Reshape to [groupsize, -1] # Reshape to [groupsize, -1]
if group_size < size_k: if group_size < size_k:
w = w.reshape((-1, group_size, size_n)) w = w.reshape((-1, group_size, size_n))
...@@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w = w.reshape((group_size, -1)) w = w.reshape((group_size, -1))
# Compute scale for each group # Compute scale for each group
s = torch.max(torch.abs(w), 0, keepdim=True)[0] max_val = torch.max(w, 0, keepdim=True).values
s *= 2 / max_q_val # 2 => symmetric min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
.clamp(min_q_val, max_q_val).int()
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
maybe_w_zp = None
# Quantize # Quantize
q_w = torch.round(w / s).int() w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
q_w += half_q_val w_q = torch.clamp(w_q, min_q_val, max_q_val)
q_w = torch.clamp(q_w, 0, max_q_val)
# Compute ref (dequantized) # Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes # Restore original shapes
if group_size < size_k: if group_size < size_k:
...@@ -119,90 +140,48 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -119,90 +140,48 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w = w.reshape((size_k, size_n)).contiguous() w = w.reshape((size_k, size_n)).contiguous()
return w return w
q_w = reshape_w(q_w) w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref) w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous() w_s = w_s.reshape((-1, size_n)).contiguous()
# Apply act_order if zero_points:
g_idx = torch.empty(0, dtype=torch.int, device=w.device) maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
rand_perm = torch.empty(0, dtype=torch.int, device=w.device) maybe_w_zp = maybe_w_zp.to(device=orig_device)
if act_order:
assert (
group_size < size_k
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k)
w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
return ( return (
w_ref.to(device=orig_device), w_ref.to(device=orig_device),
q_w.to(device=orig_device), w_q.to(device=orig_device),
s.to(device=orig_device), w_s.to(device=orig_device),
g_idx.to(device=orig_device), maybe_w_zp,
rand_perm.to(device=orig_device),
) )
def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int): def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
orig_device = w.device group_size: int, act_order: bool):
size_k, size_n = w.shape size_k, _ = w.shape
assert w.is_floating_point(), "w must be float" assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
f"Unsupported gptq type = {quant_type}"
assert group_size in SUPPORTED_GROUP_SIZES + [ assert group_size in SUPPORTED_GROUP_SIZES + [
size_k size_k
], f"Unsupported groupsize = {group_size}" ], f"Unsupported groupsize = {group_size}"
if group_size == -1: w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
group_size = size_k
assert group_size <= size_k
max_q_val = 2**num_bits - 1
min_q_val = 0
# Reshape to [groupsize, -1] # Apply act_order
if group_size < size_k: g_idx = torch.empty(0, dtype=torch.int, device=w.device)
w = w.reshape((-1, group_size, size_n)) rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
w = w.permute(1, 0, 2) if act_order:
w = w.reshape((group_size, -1)) assert (
group_size < size_k
# Compute scale for each group ), "For act_order, groupsize = {} must be less than size_k = {}".format(
max = torch.max(w, 0, keepdim=True)[0] group_size, size_k)
min = torch.min(w, 0, keepdim=True)[0]
s = (max - min).clamp(min=1e-5) / max_q_val
# Compute zero-point for each group
zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int()
# Quantize
q_w = torch.round(w / s).int() + zp
q_w = torch.clamp(q_w, min_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - zp).half() * s
# Restore original shapes
if group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous() w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
zp = zp.reshape((-1, size_n)).contiguous()
return ( return w_ref, w_q, w_s, g_idx, rand_perm
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
zp.to(device=orig_device),
)
# QQQ employs different quant schemes for per-group and # QQQ employs different quant schemes for per-group and
...@@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): ...@@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
size_k, size_n = w.shape size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float" assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [ assert group_size in SUPPORTED_GROUP_SIZES + [
size_k size_k
], f"Unsupported groupsize = {group_size}" ], f"Unsupported groupsize = {group_size}"
......
from ._core_ext import NanRepr, ScalarType
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means its zero
class scalar_types:
int4 = ScalarType.int_(4, None)
uint4 = ScalarType.uint(4, None)
int8 = ScalarType.int_(8, None)
uint8 = ScalarType.uint(8, None)
float8_e4m3fn = ScalarType.float_(4, 3, True,
NanRepr.EXTD_RANGE_MAX_MIN.value)
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
# "gptq" types
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)
# colloquial names
bfloat16 = float16_e8m7
float16 = float16_e5m10
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment