Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
05e7a854
Unverified
Commit
05e7a854
authored
Jun 28, 2025
by
Sayak Paul
Committed by
GitHub
Jun 28, 2025
Browse files
[lora] fix: lora unloading behvaiour (#11822)
* fix: lora unloading behvaiour * fix * update
parent
76ec3d1f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
24 deletions
+43
-24
src/diffusers/loaders/peft.py
src/diffusers/loaders/peft.py
+2
-0
tests/lora/utils.py
tests/lora/utils.py
+41
-24
No files found.
src/diffusers/loaders/peft.py
View file @
05e7a854
...
...
@@ -693,6 +693,8 @@ class PeftAdapterMixin:
recurse_remove_peft_layers
(
self
)
if
hasattr
(
self
,
"peft_config"
):
del
self
.
peft_config
if
hasattr
(
self
,
"_hf_peft_config_loaded"
):
self
.
_hf_peft_config_loaded
=
None
_maybe_remove_and_reapply_group_offloading
(
self
)
...
...
tests/lora/utils.py
View file @
05e7a854
...
...
@@ -291,9 +291,7 @@ class PeftLoraLoaderMixinTests:
return
modules_to_save
def
check_if_adapters_added_correctly
(
self
,
pipe
,
text_lora_config
=
None
,
denoiser_lora_config
=
None
,
adapter_name
=
"default"
):
def
add_adapters_to_pipeline
(
self
,
pipe
,
text_lora_config
=
None
,
denoiser_lora_config
=
None
,
adapter_name
=
"default"
):
if
text_lora_config
is
not
None
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
adapter_name
=
adapter_name
)
...
...
@@ -345,7 +343,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
output_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
...
...
@@ -428,7 +426,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
images_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
...
...
@@ -484,7 +482,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
output_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
...
...
@@ -522,7 +520,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
.
fuse_lora
()
# Fusing should still keep the LoRA layers
...
...
@@ -554,7 +552,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
.
unload_lora_weights
()
# unloading should remove the LoRA layers
...
...
@@ -589,7 +587,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
images_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
...
...
@@ -640,7 +638,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
state_dict
=
{}
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
...
...
@@ -691,7 +689,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
=
None
)
images_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
...
@@ -734,7 +732,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
images_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
...
...
@@ -775,7 +773,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
output_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
...
...
@@ -819,7 +817,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
denoiser
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
denoiser
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
.
fuse_lora
(
components
=
self
.
pipeline_class
.
_lora_loadable_modules
)
...
...
@@ -857,7 +855,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
denoiser
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
denoiser
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
.
unload_lora_weights
()
# unloading should remove the LoRA layers
...
...
@@ -893,7 +891,7 @@ class PeftLoraLoaderMixinTests:
pipe
.
set_progress_bar_config
(
disable
=
None
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
pipe
,
denoiser
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
denoiser
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
.
fuse_lora
(
components
=
self
.
pipeline_class
.
_lora_loadable_modules
)
self
.
assertTrue
(
pipe
.
num_fused_loras
==
1
,
f
"
{
pipe
.
num_fused_loras
=
}
,
{
pipe
.
fused_loras
=
}
"
)
...
...
@@ -1010,7 +1008,7 @@ class PeftLoraLoaderMixinTests:
pipe
.
set_progress_bar_config
(
disable
=
None
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
,
adapter_name
=
adapter_name
)
...
...
@@ -1032,7 +1030,7 @@ class PeftLoraLoaderMixinTests:
pipe
.
set_progress_bar_config
(
disable
=
None
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
,
adapter_name
=
adapter_name
)
...
...
@@ -1759,7 +1757,7 @@ class PeftLoraLoaderMixinTests:
output_no_dora_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_dora_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
output_dora_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
...
...
@@ -1850,7 +1848,7 @@ class PeftLoraLoaderMixinTests:
pipe
.
set_progress_bar_config
(
disable
=
None
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
.
unet
=
torch
.
compile
(
pipe
.
unet
,
mode
=
"reduce-overhead"
,
fullgraph
=
True
)
pipe
.
text_encoder
=
torch
.
compile
(
pipe
.
text_encoder
,
mode
=
"reduce-overhead"
,
fullgraph
=
True
)
...
...
@@ -1937,7 +1935,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
lora_scale
=
0.5
attention_kwargs
=
{
attention_kwargs_name
:
{
"scale"
:
lora_scale
}}
...
...
@@ -2119,7 +2117,7 @@ class PeftLoraLoaderMixinTests:
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
,
denoiser
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
denoiser
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
if
storage_dtype
is
not
None
:
denoiser
.
enable_layerwise_casting
(
storage_dtype
=
storage_dtype
,
compute_dtype
=
compute_dtype
)
...
...
@@ -2237,7 +2235,7 @@ class PeftLoraLoaderMixinTests:
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
=
text_lora_config
,
denoiser_lora_config
=
denoiser_lora_config
)
...
...
@@ -2290,7 +2288,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
check_if
_adapters_
added_correctly
(
pipe
,
_
=
self
.
add
_adapters_
to_pipeline
(
pipe
,
text_lora_config
=
text_lora_config
,
denoiser_lora_config
=
denoiser_lora_config
)
output_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
...
...
@@ -2309,6 +2307,25 @@ class PeftLoraLoaderMixinTests:
np
.
allclose
(
output_lora
,
output_lora_pretrained
,
atol
=
1e-3
,
rtol
=
1e-3
),
"Lora outputs should match."
)
def
test_lora_unload_add_adapter
(
self
):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
pipe
,
_
=
self
.
add_adapters_to_pipeline
(
pipe
,
text_lora_config
=
text_lora_config
,
denoiser_lora_config
=
denoiser_lora_config
)
_
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
# unload and then add.
pipe
.
unload_lora_weights
()
pipe
,
_
=
self
.
add_adapters_to_pipeline
(
pipe
,
text_lora_config
=
text_lora_config
,
denoiser_lora_config
=
denoiser_lora_config
)
_
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
def
test_inference_load_delete_load_adapters
(
self
):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for
scheduler_cls
in
self
.
scheduler_classes
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment