renzhc / diffusers_dcu · commit 09e777a3 (unverified)

Authored Sep 24, 2025 by Sayak Paul; committed by GitHub on Sep 24, 2025
Parent: a72bc0c4

[tests] Single scheduler in lora tests (#12315)

* single scheduler please.
* up
* up
* up
Showing 14 changed files with 1044 additions and 1143 deletions.
tests/lora/test_lora_layers_auraflow.py      +0    -1
tests/lora/test_lora_layers_cogvideox.py     +0    -2
tests/lora/test_lora_layers_cogview4.py      +17   -19
tests/lora/test_lora_layers_flux.py          +2    -4
tests/lora/test_lora_layers_hunyuanvideo.py  +0    -1
tests/lora/test_lora_layers_ltx_video.py     +0    -1
tests/lora/test_lora_layers_lumina2.py       +27   -31
tests/lora/test_lora_layers_mochi.py         +0    -1
tests/lora/test_lora_layers_qwenimage.py     +0    -1
tests/lora/test_lora_layers_sana.py          +2    -3
tests/lora/test_lora_layers_sd3.py           +0    -1
tests/lora/test_lora_layers_wan.py           +0    -1
tests/lora/test_lora_layers_wanvace.py       +1    -3
tests/lora/utils.py                          +995  -1074
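
Every test file below makes the same one-line class-attribute change. A condensed before/after sketch, using the AuraFlow names from the first hunk (the `Before`/`After` class names are illustrative, and the snippet only runs where `diffusers` is installed):

```python
from diffusers import AuraFlowPipeline, FlowMatchEulerDiscreteScheduler


class Before:
    # Old shape: a redundant list that the test mixin looped over in every test.
    pipeline_class = AuraFlowPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
    scheduler_kwargs = {}


class After:
    # New shape: `scheduler_cls` alone drives the tests.
    pipeline_class = AuraFlowPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}
```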
tests/lora/test_lora_layers_auraflow.py

```diff
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = AuraFlowPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
```
tests/lora/test_lora_layers_cogvideox.py

```diff
@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import (
     AutoencoderKLCogVideoX,
-    CogVideoXDDIMScheduler,
     CogVideoXDPMScheduler,
     CogVideoXPipeline,
     CogVideoXTransformer3DModel,
@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = CogVideoXPipeline
     scheduler_cls = CogVideoXDPMScheduler
     scheduler_kwargs = {"timestep_spacing": "trailing"}
-    scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
     transformer_kwargs = {
         "num_attention_heads": 4,
```
tests/lora/test_lora_layers_cogview4.py

```diff
@@ -50,7 +50,6 @@ class TokenizerWrapper:
 class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = CogView4Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -124,30 +123,29 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)

             pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
             pipe_from_pretrained.to(torch_device)

         images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

         self.assertTrue(
             np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results.",
         )

     @parameterized.expand([("block_level", True), ("leaf_level", False)])
     @require_torch_accelerator
```
tests/lora/test_lora_layers_flux.py

```diff
@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa
 @require_peft_backend
 class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 4,
@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxControlPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 8,
```
tests/lora/test_lora_layers_hunyuanvideo.py

```diff
@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = HunyuanVideoPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
```
tests/lora/test_lora_layers_ltx_video.py

```diff
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = LTXPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
```
tests/lora/test_lora_layers_lumina2.py

```diff
@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa
 class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = Lumina2Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -141,33 +140,30 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         strict=False,
     )
     def test_lora_fuse_nan(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
         self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

         # corrupt one LoRA weight with `inf` values
         with torch.no_grad():
             pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

         # with `safe_fusing=True` we should see an Error
         with self.assertRaises(ValueError):
             pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

         # without we should not see an error, but every image will be black
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)

         out = pipe(**inputs)[0]

         self.assertTrue(np.isnan(out).all())
```
tests/lora/test_lora_layers_mochi.py

```diff
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = MochiPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
```
tests/lora/test_lora_layers_qwenimage.py

```diff
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = QwenImagePipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
```
tests/lora/test_lora_layers_sana.py

```diff
@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
 class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = SanaPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
-    scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
+    scheduler_kwargs = {"shift": 7.0}
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 4,
```
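
Sana is the one file where a value moves rather than just disappearing: the pre-instantiated `FlowMatchEulerDiscreteScheduler(shift=7.0)` becomes the bare class plus a `scheduler_kwargs` entry. That split only makes sense if the mixin instantiates the scheduler itself; a minimal sketch of that assumed construction (the actual call site in `get_dummy_components` is elided from this diff):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {"shift": 7.0}

# Assumed: the mixin builds a fresh scheduler per test from class + kwargs.
scheduler = scheduler_cls(**scheduler_kwargs)
print(scheduler.config.shift)  # 7.0
```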
tests/lora/test_lora_layers_sd3.py

```diff
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "sample_size": 32,
         "patch_size": 1,
```
tests/lora/test_lora_layers_wan.py

```diff
@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = WanPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
```
tests/lora/test_lora_layers_wanvace.py

```diff
@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = WanVACEPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     @require_peft_version_greater("0.13.2")
     def test_lora_exclude_modules_wanvace(self):
-        scheduler_cls = self.scheduler_classes[0]
         exclude_module_name = "vace_blocks.0.proj_out"
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
```
tests/lora/utils.py

```diff
@@ -26,8 +26,6 @@ from parameterized import parameterized
 from diffusers import (
     AutoencoderKL,
-    DDIMScheduler,
-    LCMScheduler,
     UNet2DConditionModel,
 )
 from diffusers.utils import logging
```
```diff
@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
     scheduler_cls = None
     scheduler_kwargs = None
-    scheduler_classes = [DDIMScheduler, LCMScheduler]

     has_two_text_encoders = False
     has_three_text_encoders = False
```
```diff
@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

-    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
+    def get_dummy_components(self, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
         if self.has_two_text_encoders and self.has_three_text_encoders:
             raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

-        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
+        scheduler_cls = self.scheduler_cls
         rank = 4
         lora_alpha = rank if lora_alpha is None else lora_alpha
```
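
Every remaining hunk in `utils.py` is the same mechanical rewrite: the `for scheduler_cls in self.scheduler_classes:` loop disappears from a test method and its body is dedented one level, so each test now runs once with the class-level scheduler instead of once per list entry. A self-contained toy reduction of the two styles; all names here are hypothetical stand-ins, not the real mixin:

```python
import unittest


class _FakeScheduler:
    """Stand-in for a diffusers scheduler class; illustration only."""


class BeforeStyle(unittest.TestCase):
    # Old shape: every test method looped over `scheduler_classes`.
    scheduler_classes = [_FakeScheduler]

    def get_dummy_components(self, scheduler_cls=None):
        scheduler_cls = self.scheduler_classes[0] if scheduler_cls is None else scheduler_cls
        return {"scheduler": scheduler_cls()}

    def test_simple_inference(self):
        for scheduler_cls in self.scheduler_classes:
            components = self.get_dummy_components(scheduler_cls)
            self.assertIsInstance(components["scheduler"], _FakeScheduler)


class AfterStyle(unittest.TestCase):
    # New shape: a single `scheduler_cls`, no loop, body dedented one level.
    scheduler_cls = _FakeScheduler

    def get_dummy_components(self):
        return {"scheduler": self.scheduler_cls()}

    def test_simple_inference(self):
        components = self.get_dummy_components()
        self.assertIsInstance(components["scheduler"], _FakeScheduler)


if __name__ == "__main__":
    unittest.main()
```

Presumably the motivation: the LoRA code paths under test do not depend on which scheduler steps the pipeline, so the loop multiplied runtime without adding coverage.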
```diff
@@ -319,152 +316,143 @@ class PeftLoraLoaderMixinTests:
         """
         Tests a simple inference and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         _, _, inputs = self.get_dummy_inputs()
         output_no_lora = pipe(**inputs)[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

     def test_simple_inference_with_text_lora(self):
         """
         Tests a simple inference with lora attached on the text encoder
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )

     @require_peft_version_greater("0.13.1")
     def test_low_cpu_mem_usage_with_injection(self):
         """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
             self.assertTrue(
                 check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
             )
             self.assertTrue(
                 "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
                 "The LoRA params should be on 'meta' device.",
             )

             te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
             set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
             self.assertTrue(
                 "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
                 "No param should be on 'meta' device.",
             )

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
         self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
         self.assertTrue(
             "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
         )

         denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
         set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
         self.assertTrue(
             "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
         )

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
                 self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                 )
                 self.assertTrue(
-                    "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
+                    "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
                     "The LoRA params should be on 'meta' device.",
                 )

-                te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
-                set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
+                te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
+                set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
                 self.assertTrue(
-                    "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
+                    "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
                     "No param should be on 'meta' device.",
                 )

         _, _, inputs = self.get_dummy_inputs()
         output_lora = pipe(**inputs)[0]
         self.assertTrue(output_lora.shape == self.output_shape)

     @require_peft_version_greater("0.13.1")
     @require_transformers_version_greater("4.45.2")
     def test_low_cpu_mem_usage_with_loading(self):
         """Tests if we can load LoRA state dict with low_cpu_mem_usage."""
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
             self.pipeline_class.save_lora_weights(
                 save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
             )

             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)

             for module_name, module in modules_to_save.items():
                 self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

             images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
                 np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                 "Loading from saved checkpoints should give same results.",
             )

             # Now, check for `low_cpu_mem_usage.`
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)

             for module_name, module in modules_to_save.items():
                 self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

             images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
                 np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                 "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
             )

     def test_simple_inference_with_text_lora_and_scale(self):
         """
```
```diff
@@ -472,147 +460,140 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )

         attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
         output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
         self.assertTrue(
             not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
             "Lora + scale should change the output",
         )

         attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
         output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
         self.assertTrue(
             np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
             "Lora + 0 scale should lead to same result as no LoRA",
         )

     def test_simple_inference_with_text_lora_fused(self):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.fuse_lora()
         # Fusing should still keep the LoRA layers
         self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                 )

         ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertFalse(
             np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
         )

     def test_simple_inference_with_text_lora_unloaded(self):
         """
         Tests a simple inference with lora attached to text encoder, then unloads the lora weights
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
         self.assertFalse(
             check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
         )

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 self.assertFalse(
                     check_if_lora_correctly_set(pipe.text_encoder_2),
                     "Lora not correctly unloaded in text encoder 2",
                 )

         ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
             "Fused lora should change the output",
         )

     def test_simple_inference_with_text_lora_save_load(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA.
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             modules_to_save = self._get_modules_to_save(pipe)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)

             self.pipeline_class.save_lora_weights(
                 save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
             )

             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

         for module_name, module in modules_to_save.items():
             self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

         images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

         self.assertTrue(
             np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results.",
         )

     def test_simple_inference_with_partial_text_lora(self):
         """
```
```diff
@@ -620,142 +601,139 @@ class PeftLoraLoaderMixinTests:
         with different ranks and some adapters removed
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
         # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
         text_lora_config = LoraConfig(
             r=4,
             rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
             lora_alpha=4,
             target_modules=self.text_encoder_target_modules,
             init_lora_weights=False,
             use_dora=False,
         )

         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         state_dict = {}
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
             # supports missing layers (PR#8324).
             state_dict = {
                 f"text_encoder.{module_name}": param
                 for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
                 if "text_model.encoder.layers.4" not in module_name
             }

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 state_dict.update(
                     {
                         f"text_encoder_2.{module_name}": param
                         for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
                         if "text_model.encoder.layers.4" not in module_name
                     }
                 )

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )

         # Unload lora and load it back using the pipe.load_lora_weights machinery
         pipe.unload_lora_weights()
         pipe.load_lora_weights(state_dict)

         output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
             "Removing adapters should change the output",
         )

     def test_simple_inference_save_pretrained_with_text_lora(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)

             pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
             pipe_from_pretrained.to(torch_device)

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             self.assertTrue(
                 check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
                 "Lora not correctly set in text encoder",
             )

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 self.assertTrue(
-                    check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
-                    "Lora not correctly set in text encoder",
+                    check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
+                    "Lora not correctly set in text encoder 2",
                 )

         images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

         self.assertTrue(
             np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results.",
         )

     def test_simple_inference_with_text_denoiser_lora_save_load(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
             self.pipeline_class.save_lora_weights(
                 save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
             )

             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

         for module_name, module in modules_to_save.items():
             self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

         images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results.",
         )

     def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         """
```
@@ -763,120 +741,112 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
output_no_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
output_no_lora
.
shape
==
self
.
output_shape
)
pipe
,
_
=
self
.
add_adapters_to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
pipe
,
_
=
self
.
add_adapters_to_pipeline
(
pipe
,
text_lora_config
,
denoiser_lora_config
)
output_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
not
np
.
allclose
(
output_lora
,
output_no_lora
,
atol
=
1e-3
,
rtol
=
1e-3
),
"Lora should change the output"
)
output_lora
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
not
np
.
allclose
(
output_lora
,
output_no_lora
,
atol
=
1e-3
,
rtol
=
1e-3
),
"Lora should change the output"
)
attention_kwargs
=
{
attention_kwargs_name
:
{
"scale"
:
0.5
}}
output_lora_scale
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
),
**
attention_kwargs
)[
0
]
attention_kwargs
=
{
attention_kwargs_name
:
{
"scale"
:
0.5
}}
output_lora_scale
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
),
**
attention_kwargs
)[
0
]
self
.
assertTrue
(
not
np
.
allclose
(
output_lora
,
output_lora_scale
,
atol
=
1e-3
,
rtol
=
1e-3
),
"Lora + scale should change the output"
,
)
self
.
assertTrue
(
not
np
.
allclose
(
output_lora
,
output_lora_scale
,
atol
=
1e-3
,
rtol
=
1e-3
),
"Lora + scale should change the output"
,
)
attention_kwargs
=
{
attention_kwargs_name
:
{
"scale"
:
0.0
}}
output_lora_0_scale
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
),
**
attention_kwargs
)[
0
]
attention_kwargs
=
{
attention_kwargs_name
:
{
"scale"
:
0.0
}}
output_lora_0_scale
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
),
**
attention_kwargs
)[
0
]
self
.
assertTrue
(
np
.
allclose
(
output_no_lora
,
output_lora_0_scale
,
atol
=
1e-3
,
rtol
=
1e-3
),
"Lora + 0 scale should lead to same result as no LoRA"
,
)
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
self
.
assertTrue
(
np
.
allclose
(
output_no_lora
,
output_lora_0_scale
,
atol
=
1e-3
,
rtol
=
1e-3
)
,
"
Lora + 0 scale should lead to same result as no LoRA
"
,
pipe
.
text_encoder
.
text_model
.
encoder
.
layers
[
0
].
self_attn
.
q_proj
.
scaling
[
"default"
]
==
1.0
,
"
The scaling parameter has not been correctly restored!
"
,
)
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
self
.
assertTrue
(
pipe
.
text_encoder
.
text_model
.
encoder
.
layers
[
0
].
self_attn
.
q_proj
.
scaling
[
"default"
]
==
1.0
,
"The scaling parameter has not been correctly restored!"
,
)
    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

        # Fusing should still keep the LoRA layers
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )
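fuse_lora() bakes the adapter deltas into the base weights (W <- W + scale * B @ A per patched module), so inference no longer pays the adapter indirection; the test above only checks that the fused output still differs from the LoRA-free baseline. A short sketch of the round trip, reusing the placeholder pipeline from the earlier sketch:

# `pipe` as constructed above (placeholder checkpoint and LoRA ids)
pipe.fuse_lora()     # folds the LoRA deltas into the base weights
image = pipe("a prompt").images[0]
pipe.unfuse_lora()   # subtracts them back out, restoring the base weights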
    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        pipe.unload_lora_weights()
        # unloading should remove the LoRA layers
        self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
        self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertFalse(
                    check_if_lora_correctly_set(pipe.text_encoder_2),
                    "Lora not correctly unloaded in text encoder 2",
                )

        output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
            "Unloading LoRA should restore the LoRA-free output",
        )
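unload_lora_weights() goes one step further than disabling: it strips the PEFT layers from every component, so the pipeline is numerically the base model again, which is exactly what the allclose check above asserts. Sketch, continuing the placeholder pipeline:

# `pipe` as constructed above (placeholder checkpoint and LoRA ids)
pipe.unload_lora_weights()
# from here on the pipeline behaves as if no LoRA had ever been loaded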
    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
...
@@ -885,125 +855,120 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
        output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
        output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # unfusing should still keep the LoRA layers attached
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

        self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                )

        # Fuse and unfuse should lead to the same results
        self.assertTrue(
            np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )
    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )

        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )

        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )

        # the three adapter configurations should all disagree with one another
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )
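set_adapters() selects which of the attached adapters participate in the forward pass, and disable_lora() turns them all off without removing them. A sketch of the three configurations the test compares, with placeholder adapter names:

# `pipe` as constructed in the earlier sketch (placeholder checkpoint)
pipe.load_lora_weights("user/style-lora", adapter_name="style")      # placeholder ids
pipe.load_lora_weights("user/subject-lora", adapter_name="subject")

pipe.set_adapters("style")                # only "style" active
pipe.set_adapters(["style", "subject"])   # both active, weight 1.0 each
pipe.disable_lora()                       # back to the base-model output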
    def test_wrong_adapter_name_raises_error(self):
        adapter_name = "adapter-1"

-       scheduler_cls = self.scheduler_classes[0]
-       components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...
@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests:
    def test_multiple_wrong_adapter_name_raises_error(self):
        adapter_name = "adapter-1"

-       scheduler_cls = self.scheduler_classes[0]
-       components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...
@@ -1054,131 +1018,127 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and set different weights for different blocks (i.e. block lora)
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
        pipe.set_adapters("adapter-1", weights_1)
        output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        weights_2 = {"unet": {"up": 5}}
        pipe.set_adapters("adapter-1", weights_2)
        output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
            "LoRA weights 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
            "No adapter and LoRA weights 1 should give different results",
        )
        self.assertFalse(
            np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
            "No adapter and LoRA weights 2 should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )
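The weights argument of set_adapters() accepts a nested dict keyed by component name, so a single adapter can be scaled differently per model part and, for the UNet, per block group, mirroring weights_1/weights_2 above. A sketch, reusing the placeholder "style" adapter from earlier:

# text encoder scaled globally, UNet scaled only in its down-blocks
pipe.set_adapters("style", {"text_encoder": 2.0, "unet": {"down": 5.0}})
# anything omitted keeps the default scale of 1.0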
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set different weights for different blocks (i.e. block lora)
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
        scales_2 = {"unet": {"down": 5, "mid": 5}}

        pipe.set_adapters("adapter-1", scales_1)
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters("adapter-2", scales_2)
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # the three adapter configurations should all disagree with one another
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )

        # a mismatching number of adapter_names and adapter_weights should raise an error
        with self.assertRaises(ValueError):
            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])

    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
...
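The test collapsed above enumerates the accepted shapes for that weights argument. A sketch of shapes the suite is expected to accept; the authoritative list is the test body itself, so treat these as illustrative assumptions:

pipe.set_adapters("style", 0.8)                             # one float for everything
pipe.set_adapters("style", {"unet": 0.8})                   # per component
pipe.set_adapters("style", {"unet": {"down": 0.8}})         # per block group
pipe.set_adapters("style", {"unet": {"down": [0.8, 0.6]}})  # per individual block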
@@ -1274,170 +1234,164 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set/delete them
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.delete_adapters("adapter-1")
        output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "After deleting adapter 1, the output should match the adapter 2 output",
        )

        pipe.delete_adapters("adapter-2")
        output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
            "output with no lora and output with all adapters deleted should give same results",
        )

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        pipe.set_adapters(["adapter-1", "adapter-2"])
        pipe.delete_adapters(["adapter-1", "adapter-2"])

        output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
            "output with no lora and output with all adapters deleted should give same results",
        )
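delete_adapters() removes an adapter from every component it was attached to, as opposed to disable_lora(), which merely bypasses it. Sketch with the placeholder names from before:

pipe.delete_adapters("style")               # gone from text encoder and denoiser alike
pipe.delete_adapters(["style", "subject"])  # list form deletes several at once
# with everything deleted, outputs match the never-loaded baseline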
    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # the three adapter configurations should all disagree with one another
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
        output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Weighted adapter and mixed adapter should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )
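Passing a list of floats alongside the adapter names blends the adapters with per-adapter weights; the test asserts that a [0.5, 0.6] mix produces something different from the default 1.0/1.0 mix. Sketch:

pipe.set_adapters(["style", "subject"], [0.5, 0.6])
# adapter_names and adapter_weights must be the same length; a mismatch
# raises ValueError (covered by the block-lora test above)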
    @skip_mps
    @pytest.mark.xfail(
...
@@ -1445,164 +1399,157 @@ class PeftLoraLoaderMixinTests:
        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
        strict=False,
    )
    def test_lora_fuse_nan(self):
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        # corrupt one LoRA weight with `inf` values
        with torch.no_grad():
            if self.unet_kwargs:
                pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
            else:
                named_modules = [name for name, _ in pipe.transformer.named_modules()]
                possible_tower_names = [
                    "transformer_blocks",
                    "blocks",
                    "joint_transformer_blocks",
                    "single_transformer_blocks",
                ]
                filtered_tower_names = [
                    tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
                ]
                if len(filtered_tower_names) == 0:
                    reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
                    raise ValueError(reason)
                for tower_name in filtered_tower_names:
                    transformer_tower = getattr(pipe.transformer, tower_name)
                    has_attn1 = any("attn1" in name for name in named_modules)
                    if has_attn1:
                        transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
                    else:
                        transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

        # with `safe_fusing=True` we should see an Error
        with self.assertRaises(ValueError):
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

        # without we should not see an error, but every image will be black
        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
        out = pipe(**inputs)[0]

        self.assertTrue(np.isnan(out).all())
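safe_fusing is the guard rail this test exercises: with it on, fuse_lora() validates the adapter weights and refuses to bake non-finite values into the base model. Sketch:

try:
    pipe.fuse_lora(safe_fusing=True)   # raises ValueError if any LoRA weight is NaN/inf
except ValueError:
    print("corrupted LoRA weights, refusing to fuse")
# fuse_lora(safe_fusing=False) skips the check; a corrupted adapter then
# poisons the base weights and every output becomes NaN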
    def test_get_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        adapter_names = pipe.get_active_adapters()
        self.assertListEqual(adapter_names, ["adapter-1"])

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")

        adapter_names = pipe.get_active_adapters()
        self.assertListEqual(adapter_names, ["adapter-2"])

        pipe.set_adapters(["adapter-1", "adapter-2"])
        self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
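get_active_adapters() reports which adapters are currently steering the pipeline; note the test's expectation that attaching a new adapter makes it the active one until set_adapters() says otherwise. Sketch with placeholder names:

pipe.load_lora_weights("user/style-lora", adapter_name="style")      # placeholder ids
pipe.load_lora_weights("user/subject-lora", adapter_name="subject")
print(pipe.get_active_adapters())        # ["subject"], the most recently attached
pipe.set_adapters(["style", "subject"])
print(pipe.get_active_adapters())        # ["style", "subject"]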
    def test_get_list_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
-       for scheduler_cls in self.scheduler_classes:
-           components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+       components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # 1.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            dicts_to_be_checked = {"text_encoder": ["adapter-1"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
            dicts_to_be_checked.update({"unet": ["adapter-1"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
            dicts_to_be_checked.update({"transformer": ["adapter-1"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

        # 2.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

        # 3.
        pipe.set_adapters(["adapter-1", "adapter-2"])

        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
        else:
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

        self.assertDictEqual(
            pipe.get_list_adapters(),
            dicts_to_be_checked,
        )

        # 4.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
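get_list_adapters() is the inventory counterpart: it maps each component to every adapter attached to it, active or not, which is why the test can attach "adapter-3" to the denoiser alone and see it only under that key. Sketch:

print(pipe.get_list_adapters())
# e.g. {"text_encoder": ["style", "subject"], "unet": ["style", "subject", "extra"]}
# unlike get_active_adapters(), inactive adapters are included too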
    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
...
@@ -1612,8 +1559,83 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet and multi-adapter case
        """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

        # set them to multi-adapter inference mode
        pipe.set_adapters(["adapter-1", "adapter-2"])
        outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1"])
        outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        # Fusing should still keep the LoRA layers so output should remain the same
        outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )

        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

        self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers")

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"])
        self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        # Fusing should still keep the LoRA layers
        output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )
        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
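Outside the test harness, the fuse → verify → unfuse flow above maps onto the public pipeline API roughly as follows. This is a minimal sketch, assuming an SDXL base model and two adapter repos used in the diffusers docs; the ids are illustrative and not part of this diff:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two adapters under explicit names, then activate both at once.
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.set_adapters(["pixel", "toy"], adapter_weights=[1.0, 0.8])

# Fusing folds the scaled LoRA deltas into the base weights, so the fused
# forward pass runs without extra LoRA matmuls; outputs should match the unfused run.
pipe.fuse_lora(adapter_names=["pixel", "toy"])
image = pipe("toy face in pixel art style", num_inference_steps=20).images[0]

# Unfusing subtracts the deltas back out while keeping the LoRA layers attached.
pipe.unfuse_lora()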
@@ -1627,150 +1649,65 @@ class PeftLoraLoaderMixinTests:
    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for lora_scale in [1.0, 0.8]:
-            for scheduler_cls in self.scheduler_classes:
-                components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly set in text encoder 2",
                    )

            pipe.set_adapters(["adapter-1"])
            attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules,
                adapter_names=["adapter-1"],
                lora_scale=lora_scale,
            )
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            # Fusing should still keep the LoRA layers so output should remain the same
            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )
            self.assertFalse(
                np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
                "LoRA should change the output",
            )
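What this test pins down is that a per-call `scale` passed through the attention kwargs and a `lora_scale` baked in at fusion time should agree. A rough user-level sketch, assuming an SD 1.5 pipeline (where the kwarg is named `cross_attention_kwargs`) and a placeholder LoRA id:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")

# Route A: apply the LoRA at half strength through the attention kwargs.
image_a = pipe(
    "a photo of an astronaut",
    generator=torch.manual_seed(0),
    cross_attention_kwargs={"scale": 0.5},
).images[0]

# Route B: fuse the LoRA into the base weights at the same strength.
pipe.fuse_lora(lora_scale=0.5)
image_b = pipe("a photo of an astronaut", generator=torch.manual_seed(0)).images[0]
# The two images should be numerically close, which is what the test asserts.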
    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_dora_lora.shape == self.output_shape)

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
            "DoRA lora should change the output",
        )
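The `use_dora=True` flag is forwarded into `peft`'s `LoraConfig`. A minimal sketch of attaching a DoRA adapter directly to a denoiser; the target module names are typical for attention blocks but are an assumption here:

import torch
from peft import LoraConfig
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
# DoRA decomposes each weight update into a magnitude and a direction part;
# it is switched on per adapter config and needs peft>=0.9.0 (see the decorator above).
dora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    use_dora=True,
)
unet.add_adapter(dora_config)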
    def test_missing_keys_warning(self):
-        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...
@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests:
        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))

    def test_unexpected_keys_warning(self):
-        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...
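Both warning tests ultimately compare checkpoint keys with module names. A quick, assumption-light way to inspect the keys of any saved LoRA file (the filename is a placeholder):

from safetensors.torch import load_file

state_dict = load_file("pytorch_lora_weights.safetensors")
# Keys are prefixed per component (e.g. "unet." or "transformer."), followed by
# the module path and the lora_A / lora_B (or lora.down / lora.up) suffixes.
for key in sorted(state_dict)[:10]:
    print(key, tuple(state_dict[key].shape))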
@@ -1842,23 +1778,22 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

        if self.has_two_text_encoders or self.has_three_text_encoders:
            pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

        # Just makes sure it works.
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
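As a rough user-level equivalent of this smoke test: compile after the adapters are attached so the LoRA layers are part of the compiled graph. Model and adapter ids are illustrative; the compile flags mirror the test:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")

# Compile the LoRA-augmented modules; the first call triggers compilation.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

# Like the test, this only needs to run end to end without errors.
_ = pipe("a smoke test prompt", num_inference_steps=2).images[0]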
    def test_modify_padding_mode(self):
        def set_pad_mode(network, mode="circular"):
...
@@ -1866,22 +1801,20 @@ class PeftLoraLoaderMixinTests:
            if isinstance(module, torch.nn.Conv2d):
                module.padding_mode = mode

-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _pad_mode = "circular"
        set_pad_mode(pipe.vae, _pad_mode)
        set_pad_mode(pipe.unet, _pad_mode)

        _, _, inputs = self.get_dummy_inputs()
        _ = pipe(**inputs)[0]
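Circular padding is the usual trick for seamlessly tileable outputs, which is why the helper walks every `Conv2d` in the module tree. A standalone sketch of the same helper:

import torch

def set_pad_mode(network: torch.nn.Module, mode: str = "circular") -> None:
    # padding_mode is a plain attribute on Conv2d, so it can be toggled in place
    # without rebuilding the layer.
    for module in network.modules():
        if isinstance(module, torch.nn.Conv2d):
            module.padding_mode = mode

# e.g. set_pad_mode(pipe.vae) and set_pad_mode(pipe.unet) before generating a
# texture whose left/right and top/bottom edges wrap around.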
    def test_logs_info_when_no_lora_keys_found(self):
-        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...
@@ -1925,73 +1858,71 @@ class PeftLoraLoaderMixinTests:
    def test_set_adapters_match_attention_kwargs(self):
        """Test to check if outputs after `set_adapters()` and attention kwargs match."""
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        lora_scale = 0.5
        attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
        output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )

        pipe.set_adapters("default", lora_scale)
        output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )
        self.assertTrue(
            np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
            "Lora + scale should match the output of `set_adapters()`.",
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )
            self.assertTrue(
                np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results as attention_kwargs.",
            )
            self.assertTrue(
                np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results as set_adapters().",
            )
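In user code the same three-way equivalence can be checked like this; a sketch assuming an SD 1.5 pipeline, where a LoRA loaded without an explicit name is registered as `default_0`:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")  # adapter name: "default_0"

# Per-call scale through attention kwargs...
image_kwargs = pipe(
    "a castle", generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images[0]

# ...should match a persistent scale set on the adapter itself.
pipe.set_adapters("default_0", 0.5)
image_set = pipe("a castle", generator=torch.manual_seed(0)).images[0]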
    @require_peft_version_greater("0.13.2")
    def test_lora_B_bias(self):
        # Currently, this test is only relevant for Flux Control LoRA as we are not
        # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...
@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests:
        self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))

    def test_correct_lora_configs_with_different_ranks(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...
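Per-layer rank overrides are expressed in `peft` through `rank_pattern`; a small sketch of a config mixing ranks (the module names are assumptions, not taken from this diff):

from peft import LoraConfig

# Base rank 4, but bump the value projections to rank 8.
config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    rank_pattern={"to_v": 8},
)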
@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests:
                        self.assertEqual(submodule.bias.dtype, dtype_to_check)

        def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device, dtype=compute_dtype)
            pipe.set_progress_bar_config(disable=None)
...
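The `storage_dtype`/`compute_dtype` split corresponds to layerwise casting on the model side; a sketch, assuming a diffusers release that exposes `enable_layerwise_casting` on its models:

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
# Weights are stored in fp8 and upcast to bf16 only while a layer computes,
# which is the configuration these tests exercise together with LoRA.
unet.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)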
@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
                    self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)

        # 1. Test forward with add_adapter
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device, dtype=compute_dtype)
        pipe.set_progress_bar_config(disable=None)
...
@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests:
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

-            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            components, _, _ = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device, dtype=compute_dtype)
            pipe.set_progress_bar_config(disable=None)
...
@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests:
    @parameterized.expand([4, 8, 16])
    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls, lora_alpha=lora_alpha)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
        pipe = self.pipeline_class(**components)
        pipe, _ = self.add_adapters_to_pipeline(
...
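The reason `lora_alpha` must survive the metadata round-trip is that the effective update is scaled by `alpha / r`, so the 4/8/16 parameterization changes output magnitudes. In plain tensors:

import torch

r, lora_alpha, fan_in, fan_out = 4, 16, 32, 32
lora_A = torch.randn(r, fan_in) * 0.01
lora_B = torch.randn(fan_out, r) * 0.01
# W' = W + (alpha / r) * B @ A  — losing alpha on reload silently rescales the adapter.
delta_w = (lora_alpha / r) * (lora_B @ lora_A)
print(delta_w.shape)  # torch.Size([32, 32])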
@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests:
    @parameterized.expand([4, 8, 16])
    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls, lora_alpha=lora_alpha)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
        pipe = self.pipeline_class(**components).to(torch_device)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
...
@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests:
    def test_lora_unload_add_adapter(self):
        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
...
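The unload-then-add sequence looks like this from the user side (model id and config are illustrative):

import torch
from peft import LoraConfig
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")

# Drop all loaded LoRA layers, then attach a fresh, randomly initialized adapter.
pipe.unload_lora_weights()
pipe.unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_v"]))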
@@ -2330,51 +2254,48 @@ class PeftLoraLoaderMixinTests:
    def test_inference_load_delete_load_adapters(self):
        "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

            # First, delete adapter and compare.
            pipe.delete_adapters(pipe.get_active_adapters()[0])
            output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
            self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))

            # Then load adapter and compare.
            pipe.load_lora_weights(tmpdirname)
            output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
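A user-level sketch of the same load → delete → reload round trip (adapter id illustrative; `get_active_adapters()` returns the currently active adapter names):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
print(pipe.get_active_adapters())  # e.g. ['default_0']

# Deleting the active adapter should restore base-model outputs...
pipe.delete_adapters(pipe.get_active_adapters()[0])

# ...and reloading the same weights should reproduce the adapted outputs.
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")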
    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
...
@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
        onload_device = torch_device
        offload_device = torch.device("cpu")

-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...
@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests:
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

-            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            components, _, _ = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
...
@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests:
    @require_torch_accelerator
    def test_lora_loading_model_cpu_offload(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, _, denoiser_lora_config = self.get_dummy_components()
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
...
@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests:
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )
            # reinitialize the pipeline to mimic the inference workflow.
-            components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+            components, _, denoiser_lora_config = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe.enable_model_cpu_offload(device=torch_device)
            pipe.load_lora_weights(tmpdirname)
...
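And the CPU-offload path this test guards, sketched against a real pipeline (ids illustrative): with model offload enabled, `load_lora_weights` has to cooperate with the accelerate hooks instead of assuming everything lives on one device.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Offload hooks move each component to the accelerator only while it runs,
# so there is deliberately no plain .to("cuda") here.
pipe.enable_model_cpu_offload()
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
_ = pipe("a lighthouse at dusk", num_inference_steps=4).images[0]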