renzhc/diffusers_dcu · commit 1a048124 (unverified)

[bitsandbytes] improve replacement warnings for bnb (#11132)

* improve replacement warnings for bnb
* updates to docs.

Authored Apr 08, 2025 by Sayak Paul; committed via GitHub on Apr 08, 2025
Parent: 4b27c4a4

3 changed files with 38 additions and 6 deletions:
src/diffusers/quantizers/bitsandbytes/utils.py   (+10, -6)
tests/quantization/bnb/test_4bit.py              (+14, -0)
tests/quantization/bnb/test_mixed_int8.py        (+14, -0)
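The substance of the change: replace_with_bnb_linear and dequantize_and_replace no longer trust the has_been_replaced flag threaded back through their recursive helpers; they recompute it by scanning the returned model's named_modules(), so the warnings fire exactly when no module was actually converted. Before the diffs, a minimal sketch of the case the improved warning covers, assuming a diffusers install with bitsandbytes >= 0.43.2; the Conv-only model is purely illustrative and mirrors the new tests:

```python
import torch

from diffusers import BitsAndBytesConfig
from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear

# A model with no nn.Linear layers: bnb has nothing to replace,
# so the (now reliable) warning should be emitted.
model = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
config = BitsAndBytesConfig(load_in_4bit=True)

model = replace_with_bnb_linear(model, quantization_config=config)
# Expected log: "You are loading your model in 8bit or 4bit but no linear
# modules were found in your model."
```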
src/diffusers/quantizers/bitsandbytes/utils.py

```diff
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     models by reducing the precision of the weights and activations, thus making models more efficient in terms
     of both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
+    model, _ = _replace_with_bnb_linear(
         model, modules_to_not_convert, current_key_name, quantization_config
     )
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
 
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
```
Later in the same file, dequantize_and_replace gets the same treatment:

```diff
@@ -283,16 +285,18 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
+    has_been_replaced = any(
+        isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+    )
 
     if not has_been_replaced:
         logger.warning(
-            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+            "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
         )
 
     return model
```
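Here success is judged by whether the returned model contains plain torch.nn.Linear modules, and the vague "for some reason" message is replaced with one that says what to check. The predicate in isolation, again with a hypothetical helper name:

```python
import torch


def contains_plain_linear(model: torch.nn.Module) -> bool:
    # Counterpart of the check added to dequantize_and_replace: after a
    # successful dequantization, ordinary torch.nn.Linear layers are back.
    return any(
        isinstance(module, torch.nn.Linear) for _, module in model.named_modules()
    )
```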
tests/quantization/bnb/test_4bit.py

```diff
@@ -70,6 +70,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
```
```diff
@@ -371,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests):
         assert key_to_target in str(err_context.exception)
 
+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
```
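One detail in the new test that may look opaque: logger.setLevel(30) uses the stdlib's numeric logging level, and 30 is WARNING, so the warning passes through to CaptureLogger while INFO/DEBUG chatter is suppressed. A stdlib-only check:

```python
import logging

# setLevel(30) in the tests is the numeric spelling of WARNING.
print(logging.getLevelName(30))  # WARNING
print(logging.WARNING == 30)     # True
```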
tests/quantization/bnb/test_mixed_int8.py

```diff
@@ -68,6 +68,8 @@ if is_torch_available():
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
```
```diff
@@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests):
         # Check that this does not throw an error
         _ = self.model_fp16.to(torch_device)
 
+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
```
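The 8-bit test mirrors the 4-bit one, differing only in load_in_8bit and in importing replace_with_bnb_linear from the package rather than from its utils module, which suggests the package re-exports the helper. To run just the two new tests, a sketch using pytest's Python entry point; it assumes a repository checkout with pytest, torch, accelerate, and bitsandbytes >= 0.43.2 available:

```python
import pytest

# Select only the warning tests added by this commit; -k matches
# against the test method names.
pytest.main(
    [
        "tests/quantization/bnb/test_4bit.py",
        "tests/quantization/bnb/test_mixed_int8.py",
        "-k", "logs_warning_for_no_quantization",
    ]
)
```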