Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8e88495d
Unverified
Commit
8e88495d
authored
Jun 11, 2025
by
Sayak Paul
Committed by
GitHub
Jun 11, 2025
Browse files
[LoRA] support Flux Control LoRA with bnb 8bit. (#11655)
support Flux Control LoRA with bnb 8bit.
parent
b79803fe
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
2 deletions
+55
-2
src/diffusers/loaders/lora_pipeline.py
src/diffusers/loaders/lora_pipeline.py
+7
-2
tests/quantization/bnb/test_mixed_int8.py
tests/quantization/bnb/test_mixed_int8.py
+48
-0
No files found.
src/diffusers/loaders/lora_pipeline.py
View file @
8e88495d
...
...
@@ -81,12 +81,17 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
from
..quantizers.gguf.utils
import
dequantize_gguf_tensor
is_bnb_4bit_quantized
=
module
.
weight
.
__class__
.
__name__
==
"Params4bit"
is_bnb_8bit_quantized
=
module
.
weight
.
__class__
.
__name__
==
"Int8Params"
is_gguf_quantized
=
module
.
weight
.
__class__
.
__name__
==
"GGUFParameter"
if
is_bnb_4bit_quantized
and
not
is_bitsandbytes_available
():
raise
ValueError
(
"The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
)
if
is_bnb_8bit_quantized
and
not
is_bitsandbytes_available
():
raise
ValueError
(
"The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
)
if
is_gguf_quantized
and
not
is_gguf_available
():
raise
ValueError
(
"The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
...
...
@@ -97,10 +102,10 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
weight_on_cpu
=
True
device
=
torch
.
accelerator
.
current_accelerator
().
type
if
hasattr
(
torch
,
"accelerator"
)
else
"cuda"
if
is_bnb_4bit_quantized
:
if
is_bnb_4bit_quantized
or
is_bnb_8bit_quantized
:
module_weight
=
dequantize_bnb_weight
(
module
.
weight
.
to
(
device
)
if
weight_on_cpu
else
module
.
weight
,
state
=
module
.
weight
.
quant_state
,
state
=
module
.
weight
.
quant_state
if
is_bnb_4bit_quantized
else
module
.
state
,
dtype
=
model
.
dtype
,
).
data
elif
is_gguf_quantized
:
...
...
tests/quantization/bnb/test_mixed_int8.py
View file @
8e88495d
...
...
@@ -19,15 +19,18 @@ import unittest
import
numpy
as
np
import
pytest
from
huggingface_hub
import
hf_hub_download
from
PIL
import
Image
from
diffusers
import
(
BitsAndBytesConfig
,
DiffusionPipeline
,
FluxControlPipeline
,
FluxTransformer2DModel
,
SanaTransformer2DModel
,
SD3Transformer2DModel
,
logging
,
)
from
diffusers.quantizers
import
PipelineQuantizationConfig
from
diffusers.utils
import
is_accelerate_version
from
diffusers.utils.testing_utils
import
(
CaptureLogger
,
...
...
@@ -39,6 +42,7 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance
,
require_accelerate
,
require_bitsandbytes_version_greater
,
require_peft_backend
,
require_peft_version_greater
,
require_torch
,
require_torch_accelerator
,
...
...
@@ -697,6 +701,50 @@ class SlowBnb8bitFluxTests(Base8bitTests):
self
.
assertTrue
(
max_diff
<
1e-3
)
@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb4BitFluxControlWithLoraTests(Base8bitTests):
    """Check that a Flux Control LoRA loads into a bitsandbytes-quantized pipeline.

    The fixture builds a ``FluxControlPipeline`` from ``black-forest-labs/FLUX.1-dev``
    with the ``transformer`` and ``text_encoder_2`` components quantized through the
    ``bitsandbytes_8bit`` backend, then enables model CPU offload.

    NOTE(review): the class name says "4Bit" while the configuration below uses the
    8-bit backend; the name is kept unchanged so existing test selectors keep working.
    """

    def setUp(self) -> None:
        # Free memory left over from earlier tests before loading the pipeline.
        gc.collect()
        backend_empty_cache(torch_device)

        quantization = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=["transformer", "text_encoder_2"],
        )
        self.pipeline_8bit = FluxControlPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            quantization_config=quantization,
            torch_dtype=torch.float16,
        )
        self.pipeline_8bit.enable_model_cpu_offload()

    def tearDown(self):
        # Drop the pipeline reference first so the collector can reclaim it.
        del self.pipeline_8bit

        gc.collect()
        backend_empty_cache(torch_device)

    def test_lora_loading(self):
        # Load the Canny control LoRA on top of the quantized pipeline, run a short
        # generation, and compare a corner slice of the image with reference values.
        pipe = self.pipeline_8bit
        pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

        result = pipe(
            prompt=self.prompt,
            control_image=Image.new(mode="RGB", size=(256, 256)),
            height=256,
            width=256,
            max_sequence_length=64,
            output_type="np",
            num_inference_steps=8,
            generator=torch.Generator().manual_seed(42),
        )
        output = result.images

        # Bottom-right 3x3 patch of the last channel of the first image, flattened.
        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array(
            [0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292]
        )

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        within_tolerance = max_diff < 1e-3
        self.assertTrue(within_tolerance, msg=f"{out_slice=} != {expected_slice=}")
@
slow
class
BaseBnb8bitSerializationTests
(
Base8bitTests
):
def
setUp
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment