Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
09e777a3
Unverified
Commit
09e777a3
authored
Sep 24, 2025
by
Sayak Paul
Committed by
GitHub
Sep 24, 2025
Browse files
[tests] Single scheduler in lora tests (#12315)
* single scheduler please. * up * up * up
parent
a72bc0c4
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1044 additions
and
1143 deletions
+1044
-1143
tests/lora/test_lora_layers_auraflow.py
tests/lora/test_lora_layers_auraflow.py
+0
-1
tests/lora/test_lora_layers_cogvideox.py
tests/lora/test_lora_layers_cogvideox.py
+0
-2
tests/lora/test_lora_layers_cogview4.py
tests/lora/test_lora_layers_cogview4.py
+17
-19
tests/lora/test_lora_layers_flux.py
tests/lora/test_lora_layers_flux.py
+2
-4
tests/lora/test_lora_layers_hunyuanvideo.py
tests/lora/test_lora_layers_hunyuanvideo.py
+0
-1
tests/lora/test_lora_layers_ltx_video.py
tests/lora/test_lora_layers_ltx_video.py
+0
-1
tests/lora/test_lora_layers_lumina2.py
tests/lora/test_lora_layers_lumina2.py
+27
-31
tests/lora/test_lora_layers_mochi.py
tests/lora/test_lora_layers_mochi.py
+0
-1
tests/lora/test_lora_layers_qwenimage.py
tests/lora/test_lora_layers_qwenimage.py
+0
-1
tests/lora/test_lora_layers_sana.py
tests/lora/test_lora_layers_sana.py
+2
-3
tests/lora/test_lora_layers_sd3.py
tests/lora/test_lora_layers_sd3.py
+0
-1
tests/lora/test_lora_layers_wan.py
tests/lora/test_lora_layers_wan.py
+0
-1
tests/lora/test_lora_layers_wanvace.py
tests/lora/test_lora_layers_wanvace.py
+1
-3
tests/lora/utils.py
tests/lora/utils.py
+995
-1074
No files found.
tests/lora/test_lora_layers_auraflow.py
View file @
09e777a3
...
...
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
AuraFlowLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
AuraFlowPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_cogvideox.py
View file @
09e777a3
...
...
@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
from
diffusers
import
(
AutoencoderKLCogVideoX
,
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
,
CogVideoXPipeline
,
CogVideoXTransformer3DModel
,
...
...
@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class
=
CogVideoXPipeline
scheduler_cls
=
CogVideoXDPMScheduler
scheduler_kwargs
=
{
"timestep_spacing"
:
"trailing"
}
scheduler_classes
=
[
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
]
transformer_kwargs
=
{
"num_attention_heads"
:
4
,
...
...
tests/lora/test_lora_layers_cogview4.py
View file @
09e777a3
...
...
@@ -50,7 +50,6 @@ class TokenizerWrapper:
class
CogView4LoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
CogView4Pipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
@@ -124,8 +123,7 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
_
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/lora/test_lora_layers_flux.py
View file @
09e777a3
...
...
@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@
require_peft_backend
class
FluxLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
FluxPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
()
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_kwargs
=
{}
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
transformer_kwargs
=
{
"patch_size"
:
1
,
"in_channels"
:
4
,
...
...
@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class
FluxControlLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
FluxControlPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
()
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_kwargs
=
{}
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
transformer_kwargs
=
{
"patch_size"
:
1
,
"in_channels"
:
8
,
...
...
tests/lora/test_lora_layers_hunyuanvideo.py
View file @
09e777a3
...
...
@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
HunyuanVideoLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
HunyuanVideoPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_ltx_video.py
View file @
09e777a3
...
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
LTXVideoLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
LTXPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_lumina2.py
View file @
09e777a3
...
...
@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
class
Lumina2LoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
Lumina2Pipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
@@ -141,8 +140,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
strict
=
False
,
)
def
test_lora_fuse_nan
(
self
):
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -150,9 +148,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
...
tests/lora/test_lora_layers_mochi.py
View file @
09e777a3
...
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
MochiLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
MochiPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_qwenimage.py
View file @
09e777a3
...
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
QwenImageLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
QwenImagePipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_sana.py
View file @
09e777a3
...
...
@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@
require_peft_backend
class
SanaLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
SanaPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
(
shift
=
7.0
)
scheduler_kwargs
=
{}
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_kwargs
=
{
"shift"
:
7.0
}
transformer_kwargs
=
{
"patch_size"
:
1
,
"in_channels"
:
4
,
...
...
tests/lora/test_lora_layers_sd3.py
View file @
09e777a3
...
...
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class
=
StableDiffusion3Pipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_kwargs
=
{}
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
transformer_kwargs
=
{
"sample_size"
:
32
,
"patch_size"
:
1
,
...
...
tests/lora/test_lora_layers_wan.py
View file @
09e777a3
...
...
@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
WanLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
WanPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_wanvace.py
View file @
09e777a3
...
...
@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
WanVACELoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
WanVACEPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
transformer_kwargs
=
{
...
...
@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@
require_peft_version_greater
(
"0.13.2"
)
def
test_lora_exclude_modules_wanvace
(
self
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
exclude_module_name
=
"vace_blocks.0.proj_out"
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
...
...
tests/lora/utils.py
View file @
09e777a3
...
...
@@ -26,8 +26,6 @@ from parameterized import parameterized
from
diffusers
import
(
AutoencoderKL
,
DDIMScheduler
,
LCMScheduler
,
UNet2DConditionModel
,
)
from
diffusers.utils
import
logging
...
...
@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
scheduler_cls
=
None
scheduler_kwargs
=
None
scheduler_classes
=
[
DDIMScheduler
,
LCMScheduler
]
has_two_text_encoders
=
False
has_three_text_encoders
=
False
...
...
@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules
=
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
"out_proj"
]
denoiser_target_modules
=
[
"to_q"
,
"to_k"
,
"to_v"
,
"to_out.0"
]
def
get_dummy_components
(
self
,
scheduler_cls
=
None
,
use_dora
=
False
,
lora_alpha
=
None
):
def
get_dummy_components
(
self
,
use_dora
=
False
,
lora_alpha
=
None
):
if
self
.
unet_kwargs
and
self
.
transformer_kwargs
:
raise
ValueError
(
"Both `unet_kwargs` and `transformer_kwargs` cannot be specified."
)
if
self
.
has_two_text_encoders
and
self
.
has_three_text_encoders
:
raise
ValueError
(
"Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True."
)
scheduler_cls
=
self
.
scheduler_cls
if
scheduler_cls
is
None
else
scheduler_cls
scheduler_cls
=
self
.
scheduler_cls
rank
=
4
lora_alpha
=
rank
if
lora_alpha
is
None
else
lora_alpha
...
...
@@ -319,8 +316,7 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple inference and makes sure it works as expected
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -334,8 +330,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -354,17 +349,14 @@ class PeftLoraLoaderMixinTests:
@
require_peft_version_greater
(
"0.13.1"
)
def
test_low_cpu_mem_usage_with_injection
(
self
):
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
inject_adapter_in_model
(
text_lora_config
,
pipe
.
text_encoder
,
low_cpu_mem_usage
=
True
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder."
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder."
)
self
.
assertTrue
(
"meta"
in
{
p
.
device
.
type
for
p
in
pipe
.
text_encoder
.
parameters
()},
"The LoRA params should be on 'meta' device."
,
...
...
@@ -416,9 +408,7 @@ class PeftLoraLoaderMixinTests:
@
require_transformers_version_greater
(
"4.45.2"
)
def
test_low_cpu_mem_usage_with_loading
(
self
):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -460,9 +450,7 @@ class PeftLoraLoaderMixinTests:
images_lora_from_pretrained_low_cpu
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
np
.
allclose
(
images_lora_from_pretrained_low_cpu
,
images_lora_from_pretrained
,
atol
=
1e-3
,
rtol
=
1e-3
),
np
.
allclose
(
images_lora_from_pretrained_low_cpu
,
images_lora_from_pretrained
,
atol
=
1e-3
,
rtol
=
1e-3
),
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
,
)
...
...
@@ -472,9 +460,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -511,8 +497,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -543,8 +528,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -557,9 +541,7 @@ class PeftLoraLoaderMixinTests:
pipe
.
unload_lora_weights
()
# unloading should remove the LoRA layers
self
.
assertFalse
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly unloaded in text encoder"
)
self
.
assertFalse
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly unloaded in text encoder"
)
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
if
"text_encoder_2"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
...
...
@@ -578,8 +560,7 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -620,8 +601,7 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed
and makes sure it works as expected
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
_
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
_
=
self
.
get_dummy_components
()
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config
=
LoraConfig
(
r
=
4
,
...
...
@@ -680,8 +660,7 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -723,8 +702,7 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -763,9 +741,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -808,8 +784,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -824,9 +799,7 @@ class PeftLoraLoaderMixinTests:
# Fusing should still keep the LoRA layers
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
denoiser
),
"Lora not correctly set in denoiser"
)
...
...
@@ -846,8 +819,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -860,9 +832,7 @@ class PeftLoraLoaderMixinTests:
pipe
.
unload_lora_weights
()
# unloading should remove the LoRA layers
self
.
assertFalse
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly unloaded in text encoder"
)
self
.
assertFalse
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly unloaded in text encoder"
)
self
.
assertFalse
(
check_if_lora_correctly_set
(
denoiser
),
"Lora not correctly unloaded in denoiser"
)
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
...
...
@@ -885,8 +855,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -925,8 +894,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -937,9 +905,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
...
@@ -1002,8 +968,7 @@ class PeftLoraLoaderMixinTests:
def
test_wrong_adapter_name_raises_error
(
self
):
adapter_name
=
"adapter-1"
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests:
def
test_multiple_wrong_adapter_name_raises_error
(
self
):
adapter_name
=
"adapter-1"
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1054,8 +1018,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and set different weights for different blocks (i.e. block lora)
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1111,8 +1074,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set different weights for different blocks (i.e. block lora)
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1123,9 +1085,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
...
@@ -1274,8 +1234,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set/delete them
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1286,9 +1245,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
...
@@ -1368,8 +1325,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1380,9 +1336,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
...
@@ -1446,8 +1400,7 @@ class PeftLoraLoaderMixinTests:
strict
=
False
,
)
def
test_lora_fuse_nan
(
self
):
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1455,9 +1408,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
...
@@ -1466,9 +1417,9 @@ class PeftLoraLoaderMixinTests:
# corrupt one LoRA weight with `inf` values
with
torch
.
no_grad
():
if
self
.
unet_kwargs
:
pipe
.
unet
.
mid_block
.
attentions
[
0
].
transformer_blocks
[
0
].
attn1
.
to_q
.
lora_A
[
"adapter-1
"
].
weight
+=
float
(
"inf"
)
pipe
.
unet
.
mid_block
.
attentions
[
0
].
transformer_blocks
[
0
].
attn1
.
to_q
.
lora_A
[
"adapter-1"
].
weight
+=
float
(
"inf
"
)
else
:
named_modules
=
[
name
for
name
,
_
in
pipe
.
transformer
.
named_modules
()]
possible_tower_names
=
[
...
...
@@ -1481,9 +1432,7 @@ class PeftLoraLoaderMixinTests:
tower_name
for
tower_name
in
possible_tower_names
if
hasattr
(
pipe
.
transformer
,
tower_name
)
]
if
len
(
filtered_tower_names
)
==
0
:
reason
=
(
f
"`pipe.transformer` didn't have any of the following attributes:
{
possible_tower_names
}
."
)
reason
=
f
"`pipe.transformer` didn't have any of the following attributes:
{
possible_tower_names
}
."
raise
ValueError
(
reason
)
for
tower_name
in
filtered_tower_names
:
transformer_tower
=
getattr
(
pipe
.
transformer
,
tower_name
)
...
...
@@ -1508,8 +1457,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1537,8 +1485,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1612,8 +1559,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and multi-adapter case
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1624,9 +1570,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
...
...
@@ -1675,9 +1619,7 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set
(
pipe
.
text_encoder_2
),
"Unfuse should still keep LoRA layers"
)
pipe
.
fuse_lora
(
components
=
self
.
pipeline_class
.
_lora_loadable_modules
,
adapter_names
=
[
"adapter-2"
,
"adapter-1"
]
)
pipe
.
fuse_lora
(
components
=
self
.
pipeline_class
.
_lora_loadable_modules
,
adapter_names
=
[
"adapter-2"
,
"adapter-1"
])
self
.
assertTrue
(
pipe
.
num_fused_loras
==
2
,
f
"
{
pipe
.
num_fused_loras
=
}
,
{
pipe
.
fused_loras
=
}
"
)
# Fusing should still keep the LoRA layers
...
...
@@ -1693,8 +1635,7 @@ class PeftLoraLoaderMixinTests:
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
for
lora_scale
in
[
1.0
,
0.8
]:
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1746,10 +1687,7 @@ class PeftLoraLoaderMixinTests:
@
require_peft_version_greater
(
peft_version
=
"0.9.0"
)
def
test_simple_inference_with_dora
(
self
):
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
,
use_dora
=
True
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
use_dora
=
True
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1768,9 +1706,8 @@ class PeftLoraLoaderMixinTests:
)
def
test_missing_keys_warning
(
self
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
# Skip text encoder check for now as that is handled with `transformers`.
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests:
self
.
assertTrue
(
missing_key
.
replace
(
f
"
{
component
}
."
,
""
)
in
cap_logger
.
out
.
replace
(
"default_0."
,
""
))
def
test_unexpected_keys_warning
(
self
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
# Skip text encoder check for now as that is handled with `transformers`.
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1842,8 +1778,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1857,7 +1792,7 @@ class PeftLoraLoaderMixinTests:
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
pipe
.
text_encoder_2
=
torch
.
compile
(
pipe
.
text_encoder_2
,
mode
=
"reduce-overhead"
,
fullgraph
=
True
)
# Just makes sure it works.
.
# Just makes sure it works.
_
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
def
test_modify_padding_mode
(
self
):
...
...
@@ -1866,8 +1801,7 @@ class PeftLoraLoaderMixinTests:
if
isinstance
(
module
,
torch
.
nn
.
Conv2d
):
module
.
padding_mode
=
mode
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
_
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1879,9 +1813,8 @@ class PeftLoraLoaderMixinTests:
_
=
pipe
(
**
inputs
)[
0
]
def
test_logs_info_when_no_lora_keys_found
(
self
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
# Skip text encoder check for now as that is handled with `transformers`.
components
,
_
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1925,9 +1858,7 @@ class PeftLoraLoaderMixinTests:
def
test_set_adapters_match_attention_kwargs
(
self
):
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -1991,7 +1922,7 @@ class PeftLoraLoaderMixinTests:
def
test_lora_B_bias
(
self
):
# Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests:
self
.
assertFalse
(
np
.
allclose
(
lora_bias_false_output
,
lora_bias_true_output
,
atol
=
1e-3
,
rtol
=
1e-3
))
def
test_correct_lora_configs_with_different_ranks
(
self
):
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests:
self
.
assertEqual
(
submodule
.
bias
.
dtype
,
dtype_to_check
)
def
initialize_pipeline
(
storage_dtype
=
None
,
compute_dtype
=
torch
.
float32
):
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
self
.
assertTrue
(
module
.
_diffusers_hook
.
get_hook
(
_PEFT_AUTOCAST_DISABLE_HOOK
)
is
not
None
)
# 1. Test forward with add_adapter
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests:
)
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdirname
,
"pytorch_lora_weights.safetensors"
)))
components
,
_
,
_
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests:
@
parameterized
.
expand
([
4
,
8
,
16
])
def
test_lora_adapter_metadata_is_loaded_correctly
(
self
,
lora_alpha
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
,
lora_alpha
=
lora_alpha
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
lora_alpha
=
lora_alpha
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
,
_
=
self
.
add_adapters_to_pipeline
(
...
...
@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests:
@
parameterized
.
expand
([
4
,
8
,
16
])
def
test_lora_adapter_metadata_save_load_inference
(
self
,
lora_alpha
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
,
lora_alpha
=
lora_alpha
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
lora_alpha
=
lora_alpha
)
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
...
...
@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests:
def
test_lora_unload_add_adapter
(
self
):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
...
...
@@ -2330,8 +2254,7 @@ class PeftLoraLoaderMixinTests:
def
test_inference_load_delete_load_adapters
(
self
):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -2341,9 +2264,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
)
...
...
@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
onload_device
=
torch_device
offload_device
=
torch
.
device
(
"cpu"
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests:
)
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdirname
,
"pytorch_lora_weights.safetensors"
)))
components
,
_
,
_
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
...
...
@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests:
@
require_torch_accelerator
def
test_lora_loading_model_cpu_offload
(
self
):
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
...
...
@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests:
save_directory
=
tmpdirname
,
safe_serialization
=
True
,
**
lora_state_dicts
)
# reinitialize the pipeline to mimic the inference workflow.
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
enable_model_cpu_offload
(
device
=
torch_device
)
pipe
.
load_lora_weights
(
tmpdirname
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment