Unverified Commit dc6eb448 authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

Improve error msg when using bitsandbytes (#31350)

improve error msg when using bnb
parent 517df566
...@@ -60,10 +60,11 @@ class Bnb4BitHfQuantizer(HfQuantizer): ...@@ -60,10 +60,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
def validate_environment(self, *args, **kwargs): def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.") raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not (is_accelerate_available() and is_bitsandbytes_available()): if not is_accelerate_available():
raise ImportError("Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install accelerate`")
if not is_bitsandbytes_available():
raise ImportError( raise ImportError(
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` " "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
"and the latest version of bitsandbytes: `pip install -U bitsandbytes`"
) )
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
......
...@@ -61,10 +61,11 @@ class Bnb8BitHfQuantizer(HfQuantizer): ...@@ -61,10 +61,11 @@ class Bnb8BitHfQuantizer(HfQuantizer):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.") raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not (is_accelerate_available() and is_bitsandbytes_available()): if not is_accelerate_available():
raise ImportError("Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate`")
if not is_bitsandbytes_available():
raise ImportError( raise ImportError(
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` " "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
"and the latest version of bitsandbytes: `pip install -U bitsandbytes`"
) )
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment