Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1a048124
Unverified
Commit
1a048124
authored
Apr 08, 2025
by
Sayak Paul
Committed by
GitHub
Apr 08, 2025
Browse files
[bitsandbytes] improve replacement warnings for bnb (#11132)
* improve replacement warnings for bnb * updates to docs.
parent
4b27c4a4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
6 deletions
+38
-6
src/diffusers/quantizers/bitsandbytes/utils.py
src/diffusers/quantizers/bitsandbytes/utils.py
+10
-6
tests/quantization/bnb/test_4bit.py
tests/quantization/bnb/test_4bit.py
+14
-0
tests/quantization/bnb/test_mixed_int8.py
tests/quantization/bnb/test_mixed_int8.py
+14
-0
No files found.
src/diffusers/quantizers/bitsandbytes/utils.py
View file @
1a048124
...
...
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
models by reducing the precision of the weights and activations, thus making models more efficient in terms
of both storage and computation.
"""
model
,
has_been_replaced
=
_replace_with_bnb_linear
(
model
,
modules_to_not_convert
,
current_key_name
,
quantization_config
)
model
,
_
=
_replace_with_bnb_linear
(
model
,
modules_to_not_convert
,
current_key_name
,
quantization_config
)
has_been_replaced
=
any
(
isinstance
(
replaced_module
,
(
bnb
.
nn
.
Linear4bit
,
bnb
.
nn
.
Linear8bitLt
))
for
_
,
replaced_module
in
model
.
named_modules
()
)
if
not
has_been_replaced
:
logger
.
warning
(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
...
...
@@ -283,16 +285,18 @@ def dequantize_and_replace(
modules_to_not_convert
=
None
,
quantization_config
=
None
,
):
model
,
has_been_replaced
=
_dequantize_and_replace
(
model
,
_
=
_dequantize_and_replace
(
model
,
dtype
=
model
.
dtype
,
modules_to_not_convert
=
modules_to_not_convert
,
quantization_config
=
quantization_config
,
)
has_been_replaced
=
any
(
isinstance
(
replaced_module
,
torch
.
nn
.
Linear
)
for
_
,
replaced_module
in
model
.
named_modules
()
)
if
not
has_been_replaced
:
logger
.
warning
(
"
For some reason the model has not been properly dequantized. You might see unexpected behavior
."
"
Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model
."
)
return
model
...
...
tests/quantization/bnb/test_4bit.py
View file @
1a048124
...
...
@@ -70,6 +70,8 @@ if is_torch_available():
if
is_bitsandbytes_available
():
import
bitsandbytes
as
bnb
from
diffusers.quantizers.bitsandbytes.utils
import
replace_with_bnb_linear
@
require_bitsandbytes_version_greater
(
"0.43.2"
)
@
require_accelerate
...
...
@@ -371,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests):
assert
key_to_target
in
str
(
err_context
.
exception
)
def test_bnb_4bit_logs_warning_for_no_quantization(self):
    """Loading in 4-bit a model with no ``nn.Linear`` layers should emit a warning."""
    # A tiny model that deliberately contains no linear modules, so nothing
    # can be replaced by bnb Linear4bit layers.
    conv_only_model = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
    config = BitsAndBytesConfig(load_in_4bit=True)

    logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
    logger.setLevel(30)  # 30 == WARNING

    with CaptureLogger(logger) as cap_logger:
        _ = replace_with_bnb_linear(conv_only_model, quantization_config=config)

    expected_warning = (
        "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
    )
    assert expected_warning in cap_logger.out
class
BnB4BitTrainingTests
(
Base4bitTests
):
def
setUp
(
self
):
...
...
tests/quantization/bnb/test_mixed_int8.py
View file @
1a048124
...
...
@@ -68,6 +68,8 @@ if is_torch_available():
if
is_bitsandbytes_available
():
import
bitsandbytes
as
bnb
from
diffusers.quantizers.bitsandbytes
import
replace_with_bnb_linear
@
require_bitsandbytes_version_greater
(
"0.43.2"
)
@
require_accelerate
...
...
@@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests):
# Check that this does not throw an error
_
=
self
.
model_fp16
.
to
(
torch_device
)
def test_bnb_8bit_logs_warning_for_no_quantization(self):
    """Loading in 8-bit a model with no ``nn.Linear`` layers should emit a warning."""
    # A tiny model that deliberately contains no linear modules, so nothing
    # can be replaced by bnb Linear8bitLt layers.
    conv_only_model = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
    config = BitsAndBytesConfig(load_in_8bit=True)

    logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
    logger.setLevel(30)  # 30 == WARNING

    with CaptureLogger(logger) as cap_logger:
        _ = replace_with_bnb_linear(conv_only_model, quantization_config=config)

    expected_warning = (
        "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
    )
    assert expected_warning in cap_logger.out
class
Bnb8bitDeviceTests
(
Base8bitTests
):
def
setUp
(
self
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment