Unverified Commit 7836fdcc authored by Dipika Sikka's avatar Dipika Sikka Committed by GitHub
Browse files

[Misc] Fix `get_min_capability` (#5971)

parent deacb7ec
......@@ -43,7 +43,8 @@ class AWQConfig(QuantizationConfig):
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
......
......@@ -44,8 +44,9 @@ class QuantizationConfig(ABC):
"""List of supported activation dtypes."""
raise NotImplementedError
@classmethod
@abstractmethod
def get_min_capability(self) -> int:
def get_min_capability(cls) -> int:
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
......
......@@ -38,7 +38,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(self) -> int:
def get_min_capability(cls) -> int:
return 70
@staticmethod
......
......@@ -33,10 +33,9 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
# Need to figure it out
@classmethod
def get_min_capability(cls) -> int:
return 60
return 75
def get_name(self) -> str:
return "compressed_tensors"
......@@ -84,6 +83,14 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_config_filenames(cls) -> List[str]:
return []
def _check_gptq_and_marlin_can_run(self):
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < 80:
raise RuntimeError("The quantization config is not supported for ",
"the current GPU. Minimum capability: 80. ",
f"Current capability: {capability}.")
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
......@@ -126,6 +133,7 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant: BaseModel) -> "CompressedTensorsScheme":
if self._is_wNa16_group_channel(weight_quant, input_quant):
self._check_gptq_and_marlin_can_run()
if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
return CompressedTensorsW4A16Sparse24(
......
......@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig):
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
return 70
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment