Unverified commit 1a048124, authored by Sayak Paul, committed by GitHub

[bitsandbytes] improve replacement warnings for bnb (#11132)

* improve replacement warnings for bnb

* updates to docs.
parent 4b27c4a4
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     models by reducing the precision of the weights and activations, thus making models more efficient in terms
     of both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
+    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,16 +285,18 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
+    has_been_replaced = any(
+        isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
-            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+            "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
         )
     return model
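dequantize_and_replace gets the mirror-image check: after dequantization the model should once again contain ordinary torch.nn.Linear layers, so the scan looks for those instead of the bnb types. A torch-only sketch of both outcomes follows; both toy models are invented for illustration.

# Torch-only illustration of the post-dequantization check added above.
import torch

restored = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())        # dequantized as expected
still_empty = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())  # no nn.Linear recovered

def contains_plain_linear(model: torch.nn.Module) -> bool:
    # Same predicate the diff adds to dequantize_and_replace: any plain nn.Linear in the tree.
    return any(isinstance(m, torch.nn.Linear) for _, m in model.named_modules())

print(contains_plain_linear(restored))     # True  -> no warning
print(contains_plain_linear(still_empty))  # False -> the "not dequantized" warning would fire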
@@ -70,6 +70,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -371,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests):
         assert key_to_target in str(err_context.exception)
 
+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
@@ -68,6 +68,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests):
         # Check that this does not throw an error
         _ = self.model_fp16.to(torch_device)
 
+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
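For reference, the same warning can be observed outside the test classes. The snippet below is a sketch, not part of the PR, assuming diffusers with bitsandbytes and accelerate installed; it mirrors the new 4-bit test above.

# Standalone version of test_bnb_4bit_logs_warning_for_no_quantization, for illustration only.
import torch
from diffusers import BitsAndBytesConfig
from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
logger.setLevel(30)  # 30 == logging.WARNING, the same numeric level the tests set

with CaptureLogger(logger) as cap_logger:
    replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)

print("no linear modules were found" in cap_logger.out)  # expected: True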