Unverified commit 1872bde7, authored by NielsRogge, committed by GitHub

[BitsandBytes] Verify if GPU is available (#30533)

Change the order of the checks so the GPU availability check runs before the accelerate/bitsandbytes import checks.
parent 998dbe06
@@ -58,6 +58,8 @@ class Bnb4BitHfQuantizer(HfQuantizer):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
     def validate_environment(self, *args, **kwargs):
+        if not torch.cuda.is_available():
+            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not (is_accelerate_available() and is_bitsandbytes_available()):
             raise ImportError(
                 "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
@@ -70,9 +72,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
                 " sure the weights are in PyTorch format."
             )
 
-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
-
         device_map = kwargs.get("device_map", None)
         if (
             device_map is not None
@@ -58,6 +58,9 @@ class Bnb8BitHfQuantizer(HfQuantizer):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
     def validate_environment(self, *args, **kwargs):
+        if not torch.cuda.is_available():
+            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
+
         if not (is_accelerate_available() and is_bitsandbytes_available()):
             raise ImportError(
                 "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
@@ -70,9 +73,6 @@ class Bnb8BitHfQuantizer(HfQuantizer):
                 " sure the weights are in PyTorch format."
             )
 
-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
-
         device_map = kwargs.get("device_map", None)
         if (
             device_map is not None
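In both quantizers the change simply hoists the `torch.cuda.is_available()` check to the top of `validate_environment`, so a machine without a CUDA GPU fails immediately with the "No GPU found" `RuntimeError` rather than an `ImportError` about missing `accelerate`/`bitsandbytes`. The following is a minimal, standalone sketch of that ordering; the stubbed availability helpers and the `validate_environment_sketch` name are illustrative assumptions, not the actual transformers implementation.

```python
# Hedged sketch of the reordered checks (not the real transformers code).
# The availability helpers are stubbed out to show the error-reporting order.
import torch


def is_accelerate_available() -> bool:  # stub for illustration
    return False


def is_bitsandbytes_available() -> bool:  # stub for illustration
    return False


def validate_environment_sketch(**kwargs):
    # GPU check now runs first: on a CUDA-less machine the caller sees
    # "No GPU found" instead of a package-installation ImportError.
    if not torch.cuda.is_available():
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")

    # Only after a GPU is confirmed do the import checks run.
    if not (is_accelerate_available() and is_bitsandbytes_available()):
        raise ImportError(
            "Using `bitsandbytes` quantization requires Accelerate and bitsandbytes: "
            "`pip install accelerate bitsandbytes`"
        )


if __name__ == "__main__":
    try:
        validate_environment_sketch()
    except (RuntimeError, ImportError) as err:
        print(f"{type(err).__name__}: {err}")
```

On a CPU-only machine this prints the `RuntimeError` message first, which is the user-facing effect of the reordering.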