"vscode:/vscode.git/clone" did not exist on "b337647aa0ce103a84aac1e07a8fd738a5a4f13f"
Unverified Commit 7836fdcc authored by Dipika Sikka's avatar Dipika Sikka Committed by GitHub
Browse files

[Misc] Fix `get_min_capability` (#5971)

parent deacb7ec
...@@ -43,7 +43,8 @@ class AWQConfig(QuantizationConfig): ...@@ -43,7 +43,8 @@ class AWQConfig(QuantizationConfig):
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half] return [torch.half]
def get_min_capability(self) -> int: @classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs. # The AWQ kernel only supports Turing or newer GPUs.
return 75 return 75
......
...@@ -44,8 +44,9 @@ class QuantizationConfig(ABC): ...@@ -44,8 +44,9 @@ class QuantizationConfig(ABC):
"""List of supported activation dtypes.""" """List of supported activation dtypes."""
raise NotImplementedError raise NotImplementedError
@classmethod
@abstractmethod @abstractmethod
def get_min_capability(self) -> int: def get_min_capability(cls) -> int:
"""Minimum GPU capability to support the quantization method. """Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere. E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
......
...@@ -38,7 +38,7 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -38,7 +38,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return [torch.float32, torch.float16, torch.bfloat16] return [torch.float32, torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_min_capability(self) -> int: def get_min_capability(cls) -> int:
return 70 return 70
@staticmethod @staticmethod
......
...@@ -33,10 +33,9 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -33,10 +33,9 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16] return [torch.float16, torch.bfloat16]
# Need to figure it out
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 60 return 75
def get_name(self) -> str: def get_name(self) -> str:
return "compressed_tensors" return "compressed_tensors"
...@@ -84,6 +83,14 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -84,6 +83,14 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return [] return []
def _check_gptq_and_marlin_can_run(self):
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < 80:
raise RuntimeError("The quantization config is not supported for ",
"the current GPU. Minimum capability: 80. ",
f"Current capability: {capability}.")
def _is_static_tensor_w8a8(self, weight_quant: BaseModel, def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool: input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
...@@ -126,6 +133,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -126,6 +133,7 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant: BaseModel) -> "CompressedTensorsScheme": input_quant: BaseModel) -> "CompressedTensorsScheme":
if self._is_wNa16_group_channel(weight_quant, input_quant): if self._is_wNa16_group_channel(weight_quant, input_quant):
self._check_gptq_and_marlin_can_run()
if (self.quant_format == CompressionFormat.marlin_24.value if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
return CompressedTensorsW4A16Sparse24( return CompressedTensorsW4A16Sparse24(
......
...@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig): ...@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig):
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half] return [torch.half]
def get_min_capability(self) -> int: @classmethod
def get_min_capability(cls) -> int:
return 70 return 70
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment