Commit 6a376cee in renzhc/diffusers_dcu (unverified)
Authored Dec 30, 2023 by Sayak Paul; committed by GitHub on Dec 30, 2023

[LoRA] remove unnecessary components from lora peft test suite (#6401)

remove unnecessary components from lora peft suite

Parent: 9f283b01
Changes: 1 changed file with 22 additions and 52 deletions

tests/lora/test_lora_layers_peft.py  (+22, -52)
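
The removals below work because the PEFT-backed tests no longer build LoRA attention processors by hand: the LoraConfig objects that get_dummy_components already creates are enough to inject LoRA into the UNet. A minimal sketch of that path (not taken from this commit; the tiny UNet kwargs are illustrative, and add_adapter is assumed to come from diffusers' PEFT integration, which requires peft to be installed):

    from diffusers import UNet2DConditionModel
    from peft import LoraConfig

    # A small UNet in the spirit of the suite's dummy components (illustrative kwargs).
    unet = UNet2DConditionModel(
        block_out_channels=(32, 64),
        layers_per_block=2,
        sample_size=32,
        in_channels=4,
        out_channels=4,
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    )

    # Same fields as the unet_lora_config built in get_dummy_components (see the
    # context lines of the @@ -140,8 +115,6 hunk below); the rank value is illustrative.
    rank = 4
    unet_lora_config = LoraConfig(
        r=rank,
        lora_alpha=rank,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        init_lora_weights=False,
    )

    # One call injects LoRA into every targeted attention projection, which is what
    # makes the removed create_unet_lora_layers / AttnProcsLayers path unnecessary.
    unet.add_adapter(unet_lora_config)
    assert any("lora" in name for name, _ in unet.named_parameters())
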
@@ -22,7 +22,6 @@ import unittest
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from huggingface_hub.repocard import RepoCard
 from packaging import version
@@ -41,8 +40,6 @@ from diffusers import (
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
 from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
 from diffusers.utils.testing_utils import (
     floats_tensor,
@@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2):
     return models_are_equal


-def create_unet_lora_layers(unet: nn.Module):
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_processor_class = (
-            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-        )
-        lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
-    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
-    return lora_attn_procs, unet_lora_layers
-
-
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -140,8 +115,6 @@ class PeftLoraLoaderMixinTests:
             r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )

-        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
-
         if self.has_two_text_encoders:
             pipeline_components = {
                 "unet": unet,
@@ -165,11 +138,8 @@ class PeftLoraLoaderMixinTests:
                 "feature_extractor": None,
                 "image_encoder": None,
             }
-        lora_components = {
-            "unet_lora_layers": unet_lora_layers,
-            "unet_lora_attn_procs": unet_lora_attn_procs,
-        }
-        return pipeline_components, lora_components, text_lora_config, unet_lora_config
+        return pipeline_components, text_lora_config, unet_lora_config

     def get_dummy_inputs(self, with_generator=True):
         batch_size = 1
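
With lora_components gone, get_dummy_components returns a three-tuple, so each of the remaining hunks simply drops the unused second element from a test's unpacking. Schematically, a test body now consumes the return value like this (a sketch, not lines from this diff; the add_adapter calls are an assumption about the unchanged test code surrounding these hunks):

    # Inside a PeftLoraLoaderMixinTests test method (schematic).
    components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(self.torch_device)
    pipe.set_progress_bar_config(disable=None)
    pipe.text_encoder.add_adapter(text_lora_config)  # assumed use of the returned configs
    pipe.unet.add_adapter(unet_lora_config)
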
@@ -216,7 +186,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -231,7 +201,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -262,7 +232,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -309,7 +279,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -351,7 +321,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -394,7 +364,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where users could use saving utilities for LoRA.
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -459,7 +429,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -510,7 +480,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -583,7 +553,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -637,7 +607,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected - with unet
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -683,7 +653,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -730,7 +700,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -780,7 +750,7 @@ class PeftLoraLoaderMixinTests:
         multiple adapters and set them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -848,7 +818,7 @@ class PeftLoraLoaderMixinTests:
         multiple adapters and set/delete them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -938,7 +908,7 @@ class PeftLoraLoaderMixinTests:
         multiple adapters and set them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1010,7 +980,7 @@ class PeftLoraLoaderMixinTests:
     def test_lora_fuse_nan(self):
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1048,7 +1018,7 @@ class PeftLoraLoaderMixinTests:
         are the expected results
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1075,7 +1045,7 @@ class PeftLoraLoaderMixinTests:
         are the expected results
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1113,7 +1083,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected - with unet and multi-adapter case
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1175,7 +1145,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)