"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "934bebf19252da6e1f2583d92e31d583b49498a2"
Unverified commit 311f7438, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Bugfix] Fix gptq failure on T4s (#7264)

parent 469b3bc5
...@@ -126,8 +126,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -126,8 +126,7 @@ class AWQMarlinConfig(QuantizationConfig):
return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
group_size=group_size, group_size=group_size,
has_zp=has_zp, has_zp=has_zp)
min_capability=cls.get_min_capability())
class AWQMarlinLinearMethod(LinearMethodBase): class AWQMarlinLinearMethod(LinearMethodBase):
......
...@@ -136,8 +136,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -136,8 +136,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return False return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size, group_size=group_size)
min_capability=cls.get_min_capability())
class GPTQMarlinLinearMethod(LinearMethodBase): class GPTQMarlinLinearMethod(LinearMethodBase):
......
...@@ -26,12 +26,13 @@ USE_FP32_REDUCE_DEFAULT = True ...@@ -26,12 +26,13 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual impl # TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool, def query_marlin_supported_quant_types(has_zp: bool,
min_capability: Optional[int] = None): device_capability: Optional[int] = None
if min_capability is None: ):
if device_capability is None:
major, minor = current_platform.get_device_capability() major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor device_capability = major * 10 + minor
if min_capability < 80: if device_capability < 80:
return [] return []
if has_zp: if has_zp:
...@@ -48,20 +49,20 @@ def _check_marlin_supported( ...@@ -48,20 +49,20 @@ def _check_marlin_supported(
quant_type: ScalarType, quant_type: ScalarType,
group_size: Optional[int], group_size: Optional[int],
has_zp: bool, has_zp: bool,
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if min_capability is None: if device_capability is None:
major, minor = current_platform.get_device_capability() major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor device_capability = major * 10 + minor
supported_types = query_marlin_supported_quant_types( supported_types = query_marlin_supported_quant_types(
has_zp, min_capability) has_zp, device_capability)
if quant_type not in supported_types: if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. " return (False, f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} " f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, " f"are supported (for group_size = {group_size}, "
f"min_capability = {min_capability}, zp = {has_zp}).") f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
return (False, f"Marlin does not support group_size = {group_size}. " return (False, f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
...@@ -73,9 +74,9 @@ def _check_marlin_supported( ...@@ -73,9 +74,9 @@ def _check_marlin_supported(
def check_marlin_supported(quant_type: ScalarType, def check_marlin_supported(quant_type: ScalarType,
group_size: int, group_size: int,
has_zp: bool = False, has_zp: bool = False,
min_capability: Optional[int] = None) -> bool: device_capability: Optional[int] = None) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
min_capability) device_capability)
return cond return cond
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment