Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
6a376cee
Unverified
Commit
6a376cee
authored
Dec 30, 2023
by
Sayak Paul
Committed by
GitHub
Dec 30, 2023
Browse files
[LoRA] remove unnecessary components from lora peft test suite (#6401)
remove unnecessary components from lora peft suite/
parent
9f283b01
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
52 deletions
+22
-52
tests/lora/test_lora_layers_peft.py
tests/lora/test_lora_layers_peft.py
+22
-52
No files found.
tests/lora/test_lora_layers_peft.py
View file @
6a376cee
...
...
@@ -22,7 +22,6 @@ import unittest
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub.repocard
import
RepoCard
from
packaging
import
version
...
...
@@ -41,8 +40,6 @@ from diffusers import (
StableDiffusionXLPipeline
,
UNet2DConditionModel
,
)
from
diffusers.loaders
import
AttnProcsLayers
from
diffusers.models.attention_processor
import
LoRAAttnProcessor
,
LoRAAttnProcessor2_0
from
diffusers.utils.import_utils
import
is_accelerate_available
,
is_peft_available
from
diffusers.utils.testing_utils
import
(
floats_tensor
,
...
...
@@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2):
return
models_are_equal
def
create_unet_lora_layers
(
unet
:
nn
.
Module
):
lora_attn_procs
=
{}
for
name
in
unet
.
attn_processors
.
keys
():
cross_attention_dim
=
None
if
name
.
endswith
(
"attn1.processor"
)
else
unet
.
config
.
cross_attention_dim
if
name
.
startswith
(
"mid_block"
):
hidden_size
=
unet
.
config
.
block_out_channels
[
-
1
]
elif
name
.
startswith
(
"up_blocks"
):
block_id
=
int
(
name
[
len
(
"up_blocks."
)])
hidden_size
=
list
(
reversed
(
unet
.
config
.
block_out_channels
))[
block_id
]
elif
name
.
startswith
(
"down_blocks"
):
block_id
=
int
(
name
[
len
(
"down_blocks."
)])
hidden_size
=
unet
.
config
.
block_out_channels
[
block_id
]
lora_attn_processor_class
=
(
LoRAAttnProcessor2_0
if
hasattr
(
F
,
"scaled_dot_product_attention"
)
else
LoRAAttnProcessor
)
lora_attn_procs
[
name
]
=
lora_attn_processor_class
(
hidden_size
=
hidden_size
,
cross_attention_dim
=
cross_attention_dim
)
unet_lora_layers
=
AttnProcsLayers
(
lora_attn_procs
)
return
lora_attn_procs
,
unet_lora_layers
@
require_peft_backend
class
PeftLoraLoaderMixinTests
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
...
@@ -140,8 +115,6 @@ class PeftLoraLoaderMixinTests:
r
=
rank
,
lora_alpha
=
rank
,
target_modules
=
[
"to_q"
,
"to_k"
,
"to_v"
,
"to_out.0"
],
init_lora_weights
=
False
)
unet_lora_attn_procs
,
unet_lora_layers
=
create_unet_lora_layers
(
unet
)
if
self
.
has_two_text_encoders
:
pipeline_components
=
{
"unet"
:
unet
,
...
...
@@ -165,11 +138,8 @@ class PeftLoraLoaderMixinTests:
"feature_extractor"
:
None
,
"image_encoder"
:
None
,
}
lora_components
=
{
"unet_lora_layers"
:
unet_lora_layers
,
"unet_lora_attn_procs"
:
unet_lora_attn_procs
,
}
return
pipeline_components
,
lora_components
,
text_lora_config
,
unet_lora_config
return
pipeline_components
,
text_lora_config
,
unet_lora_config
def
get_dummy_inputs
(
self
,
with_generator
=
True
):
batch_size
=
1
...
...
@@ -216,7 +186,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -231,7 +201,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -262,7 +232,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -309,7 +279,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -351,7 +321,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -394,7 +364,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA.
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -459,7 +429,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -510,7 +480,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -583,7 +553,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -637,7 +607,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected - with unet
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -683,7 +653,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -730,7 +700,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -780,7 +750,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set them
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -848,7 +818,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set/delete them
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -938,7 +908,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set them
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1010,7 +980,7 @@ class PeftLoraLoaderMixinTests:
def
test_lora_fuse_nan
(
self
):
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1048,7 +1018,7 @@ class PeftLoraLoaderMixinTests:
are the expected results
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1075,7 +1045,7 @@ class PeftLoraLoaderMixinTests:
are the expected results
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1113,7 +1083,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected - with unet and multi-adapter case
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1175,7 +1145,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for
scheduler_cls
in
[
DDIMScheduler
,
LCMScheduler
]:
components
,
_
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
unet_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
self
.
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment