Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
09e777a3
Unverified
Commit
09e777a3
authored
Sep 24, 2025
by
Sayak Paul
Committed by
GitHub
Sep 24, 2025
Browse files
[tests] Single scheduler in lora tests (#12315)
* single scheduler please. * up * up * up
parent
a72bc0c4
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1044 additions
and
1143 deletions
+1044
-1143
tests/lora/test_lora_layers_auraflow.py
tests/lora/test_lora_layers_auraflow.py
+0
-1
tests/lora/test_lora_layers_cogvideox.py
tests/lora/test_lora_layers_cogvideox.py
+0
-2
tests/lora/test_lora_layers_cogview4.py
tests/lora/test_lora_layers_cogview4.py
+17
-19
tests/lora/test_lora_layers_flux.py
tests/lora/test_lora_layers_flux.py
+2
-4
tests/lora/test_lora_layers_hunyuanvideo.py
tests/lora/test_lora_layers_hunyuanvideo.py
+0
-1
tests/lora/test_lora_layers_ltx_video.py
tests/lora/test_lora_layers_ltx_video.py
+0
-1
tests/lora/test_lora_layers_lumina2.py
tests/lora/test_lora_layers_lumina2.py
+27
-31
tests/lora/test_lora_layers_mochi.py
tests/lora/test_lora_layers_mochi.py
+0
-1
tests/lora/test_lora_layers_qwenimage.py
tests/lora/test_lora_layers_qwenimage.py
+0
-1
tests/lora/test_lora_layers_sana.py
tests/lora/test_lora_layers_sana.py
+2
-3
tests/lora/test_lora_layers_sd3.py
tests/lora/test_lora_layers_sd3.py
+0
-1
tests/lora/test_lora_layers_wan.py
tests/lora/test_lora_layers_wan.py
+0
-1
tests/lora/test_lora_layers_wanvace.py
tests/lora/test_lora_layers_wanvace.py
+1
-3
tests/lora/utils.py
tests/lora/utils.py
+995
-1074
No files found.
tests/lora/test_lora_layers_auraflow.py
View file @
09e777a3
...
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
...
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
AuraFlowLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
AuraFlowLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
AuraFlowPipeline
pipeline_class
=
AuraFlowPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_cogvideox.py
View file @
09e777a3
...
@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
...
@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
from
diffusers
import
(
from
diffusers
import
(
AutoencoderKLCogVideoX
,
AutoencoderKLCogVideoX
,
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
,
CogVideoXDPMScheduler
,
CogVideoXPipeline
,
CogVideoXPipeline
,
CogVideoXTransformer3DModel
,
CogVideoXTransformer3DModel
,
...
@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
...
@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class
=
CogVideoXPipeline
pipeline_class
=
CogVideoXPipeline
scheduler_cls
=
CogVideoXDPMScheduler
scheduler_cls
=
CogVideoXDPMScheduler
scheduler_kwargs
=
{
"timestep_spacing"
:
"trailing"
}
scheduler_kwargs
=
{
"timestep_spacing"
:
"trailing"
}
scheduler_classes
=
[
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
]
transformer_kwargs
=
{
transformer_kwargs
=
{
"num_attention_heads"
:
4
,
"num_attention_heads"
:
4
,
...
...
tests/lora/test_lora_layers_cogview4.py
View file @
09e777a3
...
@@ -50,7 +50,6 @@ class TokenizerWrapper:
...
@@ -50,7 +50,6 @@ class TokenizerWrapper:
class
CogView4LoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
CogView4LoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
CogView4Pipeline
pipeline_class
=
CogView4Pipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
@@ -124,8 +123,7 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
...
@@ -124,8 +123,7 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"""
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
_
,
_
=
self
.
get_dummy_components
()
components
,
_
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
tests/lora/test_lora_layers_flux.py
View file @
09e777a3
...
@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
...
@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@
require_peft_backend
@
require_peft_backend
class
FluxLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
FluxLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
FluxPipeline
pipeline_class
=
FluxPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
()
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
transformer_kwargs
=
{
transformer_kwargs
=
{
"patch_size"
:
1
,
"patch_size"
:
1
,
"in_channels"
:
4
,
"in_channels"
:
4
,
...
@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
...
@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class
FluxControlLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
FluxControlLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
FluxControlPipeline
pipeline_class
=
FluxControlPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
()
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
transformer_kwargs
=
{
transformer_kwargs
=
{
"patch_size"
:
1
,
"patch_size"
:
1
,
"in_channels"
:
8
,
"in_channels"
:
8
,
...
...
tests/lora/test_lora_layers_hunyuanvideo.py
View file @
09e777a3
...
@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
...
@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
HunyuanVideoLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
HunyuanVideoLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
HunyuanVideoPipeline
pipeline_class
=
HunyuanVideoPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_ltx_video.py
View file @
09e777a3
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
LTXVideoLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
LTXVideoLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
LTXPipeline
pipeline_class
=
LTXPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_lumina2.py
View file @
09e777a3
...
@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
...
@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
class
Lumina2LoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
Lumina2LoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
Lumina2Pipeline
pipeline_class
=
Lumina2Pipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
@@ -141,8 +140,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
...
@@ -141,8 +140,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
strict
=
False
,
strict
=
False
,
)
)
def
test_lora_fuse_nan
(
self
):
def
test_lora_fuse_nan
(
self
):
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -150,9 +148,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
...
@@ -150,9 +148,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
...
tests/lora/test_lora_layers_mochi.py
View file @
09e777a3
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
MochiLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
MochiLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
MochiPipeline
pipeline_class
=
MochiPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_qwenimage.py
View file @
09e777a3
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
QwenImageLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
QwenImageLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
QwenImagePipeline
pipeline_class
=
QwenImagePipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_sana.py
View file @
09e777a3
...
@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
...
@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@
require_peft_backend
@
require_peft_backend
class
SanaLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
SanaLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
SanaPipeline
pipeline_class
=
SanaPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
(
shift
=
7.0
)
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_kwargs
=
{}
scheduler_kwargs
=
{
"shift"
:
7.0
}
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
transformer_kwargs
=
{
transformer_kwargs
=
{
"patch_size"
:
1
,
"patch_size"
:
1
,
"in_channels"
:
4
,
"in_channels"
:
4
,
...
...
tests/lora/test_lora_layers_sd3.py
View file @
09e777a3
...
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
...
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class
=
StableDiffusion3Pipeline
pipeline_class
=
StableDiffusion3Pipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
transformer_kwargs
=
{
transformer_kwargs
=
{
"sample_size"
:
32
,
"sample_size"
:
32
,
"patch_size"
:
1
,
"patch_size"
:
1
,
...
...
tests/lora/test_lora_layers_wan.py
View file @
09e777a3
...
@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
...
@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
WanLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
WanLoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
WanPipeline
pipeline_class
=
WanPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
...
tests/lora/test_lora_layers_wanvace.py
View file @
09e777a3
...
@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
...
@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class
WanVACELoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
class
WanVACELoRATests
(
unittest
.
TestCase
,
PeftLoraLoaderMixinTests
):
pipeline_class
=
WanVACEPipeline
pipeline_class
=
WanVACEPipeline
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_cls
=
FlowMatchEulerDiscreteScheduler
scheduler_classes
=
[
FlowMatchEulerDiscreteScheduler
]
scheduler_kwargs
=
{}
scheduler_kwargs
=
{}
transformer_kwargs
=
{
transformer_kwargs
=
{
...
@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
...
@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@
require_peft_version_greater
(
"0.13.2"
)
@
require_peft_version_greater
(
"0.13.2"
)
def
test_lora_exclude_modules_wanvace
(
self
):
def
test_lora_exclude_modules_wanvace
(
self
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
exclude_module_name
=
"vace_blocks.0.proj_out"
exclude_module_name
=
"vace_blocks.0.proj_out"
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
...
...
tests/lora/utils.py
View file @
09e777a3
...
@@ -26,8 +26,6 @@ from parameterized import parameterized
...
@@ -26,8 +26,6 @@ from parameterized import parameterized
from
diffusers
import
(
from
diffusers
import
(
AutoencoderKL
,
AutoencoderKL
,
DDIMScheduler
,
LCMScheduler
,
UNet2DConditionModel
,
UNet2DConditionModel
,
)
)
from
diffusers.utils
import
logging
from
diffusers.utils
import
logging
...
@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
...
@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
scheduler_cls
=
None
scheduler_cls
=
None
scheduler_kwargs
=
None
scheduler_kwargs
=
None
scheduler_classes
=
[
DDIMScheduler
,
LCMScheduler
]
has_two_text_encoders
=
False
has_two_text_encoders
=
False
has_three_text_encoders
=
False
has_three_text_encoders
=
False
...
@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests:
...
@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules
=
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
"out_proj"
]
text_encoder_target_modules
=
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
"out_proj"
]
denoiser_target_modules
=
[
"to_q"
,
"to_k"
,
"to_v"
,
"to_out.0"
]
denoiser_target_modules
=
[
"to_q"
,
"to_k"
,
"to_v"
,
"to_out.0"
]
def
get_dummy_components
(
self
,
scheduler_cls
=
None
,
use_dora
=
False
,
lora_alpha
=
None
):
def
get_dummy_components
(
self
,
use_dora
=
False
,
lora_alpha
=
None
):
if
self
.
unet_kwargs
and
self
.
transformer_kwargs
:
if
self
.
unet_kwargs
and
self
.
transformer_kwargs
:
raise
ValueError
(
"Both `unet_kwargs` and `transformer_kwargs` cannot be specified."
)
raise
ValueError
(
"Both `unet_kwargs` and `transformer_kwargs` cannot be specified."
)
if
self
.
has_two_text_encoders
and
self
.
has_three_text_encoders
:
if
self
.
has_two_text_encoders
and
self
.
has_three_text_encoders
:
raise
ValueError
(
"Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True."
)
raise
ValueError
(
"Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True."
)
scheduler_cls
=
self
.
scheduler_cls
if
scheduler_cls
is
None
else
scheduler_cls
scheduler_cls
=
self
.
scheduler_cls
rank
=
4
rank
=
4
lora_alpha
=
rank
if
lora_alpha
is
None
else
lora_alpha
lora_alpha
=
rank
if
lora_alpha
is
None
else
lora_alpha
...
@@ -319,8 +316,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -319,8 +316,7 @@ class PeftLoraLoaderMixinTests:
"""
"""
Tests a simple inference and makes sure it works as expected
Tests a simple inference and makes sure it works as expected
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -334,8 +330,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -334,8 +330,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
and makes sure it works as expected
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -354,17 +349,14 @@ class PeftLoraLoaderMixinTests:
...
@@ -354,17 +349,14 @@ class PeftLoraLoaderMixinTests:
@
require_peft_version_greater
(
"0.13.1"
)
@
require_peft_version_greater
(
"0.13.1"
)
def
test_low_cpu_mem_usage_with_injection
(
self
):
def
test_low_cpu_mem_usage_with_injection
(
self
):
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
inject_adapter_in_model
(
text_lora_config
,
pipe
.
text_encoder
,
low_cpu_mem_usage
=
True
)
inject_adapter_in_model
(
text_lora_config
,
pipe
.
text_encoder
,
low_cpu_mem_usage
=
True
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder."
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder."
)
self
.
assertTrue
(
self
.
assertTrue
(
"meta"
in
{
p
.
device
.
type
for
p
in
pipe
.
text_encoder
.
parameters
()},
"meta"
in
{
p
.
device
.
type
for
p
in
pipe
.
text_encoder
.
parameters
()},
"The LoRA params should be on 'meta' device."
,
"The LoRA params should be on 'meta' device."
,
...
@@ -416,9 +408,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -416,9 +408,7 @@ class PeftLoraLoaderMixinTests:
@
require_transformers_version_greater
(
"4.45.2"
)
@
require_transformers_version_greater
(
"4.45.2"
)
def
test_low_cpu_mem_usage_with_loading
(
self
):
def
test_low_cpu_mem_usage_with_loading
(
self
):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -460,9 +450,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -460,9 +450,7 @@ class PeftLoraLoaderMixinTests:
images_lora_from_pretrained_low_cpu
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
images_lora_from_pretrained_low_cpu
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
self
.
assertTrue
(
self
.
assertTrue
(
np
.
allclose
(
np
.
allclose
(
images_lora_from_pretrained_low_cpu
,
images_lora_from_pretrained
,
atol
=
1e-3
,
rtol
=
1e-3
),
images_lora_from_pretrained_low_cpu
,
images_lora_from_pretrained
,
atol
=
1e-3
,
rtol
=
1e-3
),
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
,
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
,
)
)
...
@@ -472,9 +460,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -472,9 +460,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
and makes sure it works as expected
"""
"""
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -511,8 +497,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -511,8 +497,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
and makes sure it works as expected
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -543,8 +528,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -543,8 +528,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected
and makes sure it works as expected
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -557,9 +541,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -557,9 +541,7 @@ class PeftLoraLoaderMixinTests:
pipe
.
unload_lora_weights
()
pipe
.
unload_lora_weights
()
# unloading should remove the LoRA layers
# unloading should remove the LoRA layers
self
.
assertFalse
(
self
.
assertFalse
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly unloaded in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly unloaded in text encoder"
)
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
if
"text_encoder_2"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder_2"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
...
@@ -578,8 +560,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -578,8 +560,7 @@ class PeftLoraLoaderMixinTests:
"""
"""
Tests a simple usecase where users could use saving utilities for LoRA.
Tests a simple usecase where users could use saving utilities for LoRA.
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -620,8 +601,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -620,8 +601,7 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed
with different ranks and some adapters removed
and makes sure it works as expected
and makes sure it works as expected
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
_
,
_
=
self
.
get_dummy_components
()
components
,
_
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config
=
LoraConfig
(
text_lora_config
=
LoraConfig
(
r
=
4
,
r
=
4
,
...
@@ -680,8 +660,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -680,8 +660,7 @@ class PeftLoraLoaderMixinTests:
"""
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -723,8 +702,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -723,8 +702,7 @@ class PeftLoraLoaderMixinTests:
"""
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -763,9 +741,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -763,9 +741,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
and makes sure it works as expected
"""
"""
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -808,8 +784,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -808,8 +784,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet
and makes sure it works as expected - with unet
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -824,9 +799,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -824,9 +799,7 @@ class PeftLoraLoaderMixinTests:
# Fusing should still keep the LoRA layers
# Fusing should still keep the LoRA layers
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
denoiser
),
"Lora not correctly set in denoiser"
)
self
.
assertTrue
(
check_if_lora_correctly_set
(
denoiser
),
"Lora not correctly set in denoiser"
)
...
@@ -846,8 +819,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -846,8 +819,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
and makes sure it works as expected
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -860,9 +832,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -860,9 +832,7 @@ class PeftLoraLoaderMixinTests:
pipe
.
unload_lora_weights
()
pipe
.
unload_lora_weights
()
# unloading should remove the LoRA layers
# unloading should remove the LoRA layers
self
.
assertFalse
(
self
.
assertFalse
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly unloaded in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly unloaded in text encoder"
)
self
.
assertFalse
(
check_if_lora_correctly_set
(
denoiser
),
"Lora not correctly unloaded in denoiser"
)
self
.
assertFalse
(
check_if_lora_correctly_set
(
denoiser
),
"Lora not correctly unloaded in denoiser"
)
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
...
@@ -885,8 +855,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -885,8 +855,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
and makes sure it works as expected
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -925,8 +894,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -925,8 +894,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
multiple adapters and set them
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -937,9 +905,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -937,9 +905,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
@@ -1002,8 +968,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1002,8 +968,7 @@ class PeftLoraLoaderMixinTests:
def
test_wrong_adapter_name_raises_error
(
self
):
def
test_wrong_adapter_name_raises_error
(
self
):
adapter_name
=
"adapter-1"
adapter_name
=
"adapter-1"
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests:
def
test_multiple_wrong_adapter_name_raises_error
(
self
):
def
test_multiple_wrong_adapter_name_raises_error
(
self
):
adapter_name
=
"adapter-1"
adapter_name
=
"adapter-1"
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1054,8 +1018,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1054,8 +1018,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and set different weights for different blocks (i.e. block lora)
one adapter and set different weights for different blocks (i.e. block lora)
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1111,8 +1074,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1111,8 +1074,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set different weights for different blocks (i.e. block lora)
multiple adapters and set different weights for different blocks (i.e. block lora)
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1123,9 +1085,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1123,9 +1085,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
@@ -1274,8 +1234,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1274,8 +1234,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set/delete them
multiple adapters and set/delete them
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1286,9 +1245,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1286,9 +1245,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
@@ -1368,8 +1325,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1368,8 +1325,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
multiple adapters and set them
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1380,9 +1336,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1380,9 +1336,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
@@ -1446,8 +1400,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1446,8 +1400,7 @@ class PeftLoraLoaderMixinTests:
strict
=
False
,
strict
=
False
,
)
)
def
test_lora_fuse_nan
(
self
):
def
test_lora_fuse_nan
(
self
):
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1455,9 +1408,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1455,9 +1408,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
denoiser
.
add_adapter
(
denoiser_lora_config
,
"adapter-1"
)
...
@@ -1466,9 +1417,9 @@ class PeftLoraLoaderMixinTests:
...
@@ -1466,9 +1417,9 @@ class PeftLoraLoaderMixinTests:
# corrupt one LoRA weight with `inf` values
# corrupt one LoRA weight with `inf` values
with
torch
.
no_grad
():
with
torch
.
no_grad
():
if
self
.
unet_kwargs
:
if
self
.
unet_kwargs
:
pipe
.
unet
.
mid_block
.
attentions
[
0
].
transformer_blocks
[
0
].
attn1
.
to_q
.
lora_A
[
pipe
.
unet
.
mid_block
.
attentions
[
0
].
transformer_blocks
[
0
].
attn1
.
to_q
.
lora_A
[
"adapter-1"
].
weight
+=
float
(
"adapter-1
"
"inf
"
].
weight
+=
float
(
"inf"
)
)
else
:
else
:
named_modules
=
[
name
for
name
,
_
in
pipe
.
transformer
.
named_modules
()]
named_modules
=
[
name
for
name
,
_
in
pipe
.
transformer
.
named_modules
()]
possible_tower_names
=
[
possible_tower_names
=
[
...
@@ -1481,9 +1432,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1481,9 +1432,7 @@ class PeftLoraLoaderMixinTests:
tower_name
for
tower_name
in
possible_tower_names
if
hasattr
(
pipe
.
transformer
,
tower_name
)
tower_name
for
tower_name
in
possible_tower_names
if
hasattr
(
pipe
.
transformer
,
tower_name
)
]
]
if
len
(
filtered_tower_names
)
==
0
:
if
len
(
filtered_tower_names
)
==
0
:
reason
=
(
reason
=
f
"`pipe.transformer` didn't have any of the following attributes:
{
possible_tower_names
}
."
f
"`pipe.transformer` didn't have any of the following attributes:
{
possible_tower_names
}
."
)
raise
ValueError
(
reason
)
raise
ValueError
(
reason
)
for
tower_name
in
filtered_tower_names
:
for
tower_name
in
filtered_tower_names
:
transformer_tower
=
getattr
(
pipe
.
transformer
,
tower_name
)
transformer_tower
=
getattr
(
pipe
.
transformer
,
tower_name
)
...
@@ -1508,8 +1457,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1508,8 +1457,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where we attach multiple adapters and check if the results
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
are the expected results
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1537,8 +1485,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1537,8 +1485,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where we attach multiple adapters and check if the results
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
are the expected results
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1612,8 +1559,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1612,8 +1559,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and multi-adapter case
and makes sure it works as expected - with unet and multi-adapter case
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1624,9 +1570,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1624,9 +1570,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-1"
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
,
"adapter-2"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
...
@@ -1675,9 +1619,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1675,9 +1619,7 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set
(
pipe
.
text_encoder_2
),
"Unfuse should still keep LoRA layers"
check_if_lora_correctly_set
(
pipe
.
text_encoder_2
),
"Unfuse should still keep LoRA layers"
)
)
pipe
.
fuse_lora
(
pipe
.
fuse_lora
(
components
=
self
.
pipeline_class
.
_lora_loadable_modules
,
adapter_names
=
[
"adapter-2"
,
"adapter-1"
])
components
=
self
.
pipeline_class
.
_lora_loadable_modules
,
adapter_names
=
[
"adapter-2"
,
"adapter-1"
]
)
self
.
assertTrue
(
pipe
.
num_fused_loras
==
2
,
f
"
{
pipe
.
num_fused_loras
=
}
,
{
pipe
.
fused_loras
=
}
"
)
self
.
assertTrue
(
pipe
.
num_fused_loras
==
2
,
f
"
{
pipe
.
num_fused_loras
=
}
,
{
pipe
.
fused_loras
=
}
"
)
# Fusing should still keep the LoRA layers
# Fusing should still keep the LoRA layers
...
@@ -1693,8 +1635,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1693,8 +1635,7 @@ class PeftLoraLoaderMixinTests:
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
for
lora_scale
in
[
1.0
,
0.8
]:
for
lora_scale
in
[
1.0
,
0.8
]:
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1746,10 +1687,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1746,10 +1687,7 @@ class PeftLoraLoaderMixinTests:
@
require_peft_version_greater
(
peft_version
=
"0.9.0"
)
@
require_peft_version_greater
(
peft_version
=
"0.9.0"
)
def
test_simple_inference_with_dora
(
self
):
def
test_simple_inference_with_dora
(
self
):
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
use_dora
=
True
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
,
use_dora
=
True
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1768,9 +1706,8 @@ class PeftLoraLoaderMixinTests:
...
@@ -1768,9 +1706,8 @@ class PeftLoraLoaderMixinTests:
)
)
def
test_missing_keys_warning
(
self
):
def
test_missing_keys_warning
(
self
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
# Skip text encoder check for now as that is handled with `transformers`.
# Skip text encoder check for now as that is handled with `transformers`.
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests:
...
@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests:
self
.
assertTrue
(
missing_key
.
replace
(
f
"
{
component
}
."
,
""
)
in
cap_logger
.
out
.
replace
(
"default_0."
,
""
))
self
.
assertTrue
(
missing_key
.
replace
(
f
"
{
component
}
."
,
""
)
in
cap_logger
.
out
.
replace
(
"default_0."
,
""
))
def
test_unexpected_keys_warning
(
self
):
def
test_unexpected_keys_warning
(
self
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
# Skip text encoder check for now as that is handled with `transformers`.
# Skip text encoder check for now as that is handled with `transformers`.
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1842,8 +1778,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1842,8 +1778,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
and makes sure it works as expected
"""
"""
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1857,7 +1792,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1857,7 +1792,7 @@ class PeftLoraLoaderMixinTests:
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
if
self
.
has_two_text_encoders
or
self
.
has_three_text_encoders
:
pipe
.
text_encoder_2
=
torch
.
compile
(
pipe
.
text_encoder_2
,
mode
=
"reduce-overhead"
,
fullgraph
=
True
)
pipe
.
text_encoder_2
=
torch
.
compile
(
pipe
.
text_encoder_2
,
mode
=
"reduce-overhead"
,
fullgraph
=
True
)
# Just makes sure it works.
.
# Just makes sure it works.
_
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
_
=
pipe
(
**
inputs
,
generator
=
torch
.
manual_seed
(
0
))[
0
]
def
test_modify_padding_mode
(
self
):
def
test_modify_padding_mode
(
self
):
...
@@ -1866,8 +1801,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1866,8 +1801,7 @@ class PeftLoraLoaderMixinTests:
if
isinstance
(
module
,
torch
.
nn
.
Conv2d
):
if
isinstance
(
module
,
torch
.
nn
.
Conv2d
):
module
.
padding_mode
=
mode
module
.
padding_mode
=
mode
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
_
,
_
=
self
.
get_dummy_components
()
components
,
_
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1879,9 +1813,8 @@ class PeftLoraLoaderMixinTests:
...
@@ -1879,9 +1813,8 @@ class PeftLoraLoaderMixinTests:
_
=
pipe
(
**
inputs
)[
0
]
_
=
pipe
(
**
inputs
)[
0
]
def
test_logs_info_when_no_lora_keys_found
(
self
):
def
test_logs_info_when_no_lora_keys_found
(
self
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
# Skip text encoder check for now as that is handled with `transformers`.
# Skip text encoder check for now as that is handled with `transformers`.
components
,
_
,
_
=
self
.
get_dummy_components
(
scheduler_cls
)
components
,
_
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1925,9 +1858,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1925,9 +1858,7 @@ class PeftLoraLoaderMixinTests:
def
test_set_adapters_match_attention_kwargs
(
self
):
def
test_set_adapters_match_attention_kwargs
(
self
):
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
attention_kwargs_name
=
determine_attention_kwargs_name
(
self
.
pipeline_class
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1991,7 +1922,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -1991,7 +1922,7 @@ class PeftLoraLoaderMixinTests:
def
test_lora_B_bias
(
self
):
def
test_lora_B_bias
(
self
):
# Currently, this test is only relevant for Flux Control LoRA as we are not
# Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests:
self
.
assertFalse
(
np
.
allclose
(
lora_bias_false_output
,
lora_bias_true_output
,
atol
=
1e-3
,
rtol
=
1e-3
))
self
.
assertFalse
(
np
.
allclose
(
lora_bias_false_output
,
lora_bias_true_output
,
atol
=
1e-3
,
rtol
=
1e-3
))
def
test_correct_lora_configs_with_different_ranks
(
self
):
def
test_correct_lora_configs_with_different_ranks
(
self
):
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests:
self
.
assertEqual
(
submodule
.
bias
.
dtype
,
dtype_to_check
)
self
.
assertEqual
(
submodule
.
bias
.
dtype
,
dtype_to_check
)
def
initialize_pipeline
(
storage_dtype
=
None
,
compute_dtype
=
torch
.
float32
):
def
initialize_pipeline
(
storage_dtype
=
None
,
compute_dtype
=
torch
.
float32
):
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
self
.
assertTrue
(
module
.
_diffusers_hook
.
get_hook
(
_PEFT_AUTOCAST_DISABLE_HOOK
)
is
not
None
)
self
.
assertTrue
(
module
.
_diffusers_hook
.
get_hook
(
_PEFT_AUTOCAST_DISABLE_HOOK
)
is
not
None
)
# 1. Test forward with add_adapter
# 1. Test forward with add_adapter
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests:
)
)
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdirname
,
"pytorch_lora_weights.safetensors"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdirname
,
"pytorch_lora_weights.safetensors"
)))
components
,
_
,
_
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
=
pipe
.
to
(
torch_device
,
dtype
=
compute_dtype
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests:
@
parameterized
.
expand
([
4
,
8
,
16
])
@
parameterized
.
expand
([
4
,
8
,
16
])
def
test_lora_adapter_metadata_is_loaded_correctly
(
self
,
lora_alpha
):
def
test_lora_adapter_metadata_is_loaded_correctly
(
self
,
lora_alpha
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
lora_alpha
=
lora_alpha
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
,
lora_alpha
=
lora_alpha
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
,
_
=
self
.
add_adapters_to_pipeline
(
pipe
,
_
=
self
.
add_adapters_to_pipeline
(
...
@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests:
@
parameterized
.
expand
([
4
,
8
,
16
])
@
parameterized
.
expand
([
4
,
8
,
16
])
def
test_lora_adapter_metadata_save_load_inference
(
self
,
lora_alpha
):
def
test_lora_adapter_metadata_save_load_inference
(
self
,
lora_alpha
):
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
lora_alpha
=
lora_alpha
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
,
lora_alpha
=
lora_alpha
)
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
...
@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests:
def
test_lora_unload_add_adapter
(
self
):
def
test_lora_unload_add_adapter
(
self
):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
scheduler_cls
=
self
.
scheduler_classes
[
0
]
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
pipe
=
self
.
pipeline_class
(
**
components
).
to
(
torch_device
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
...
@@ -2330,8 +2254,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2330,8 +2254,7 @@ class PeftLoraLoaderMixinTests:
def
test_inference_load_delete_load_adapters
(
self
):
def
test_inference_load_delete_load_adapters
(
self
):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for
scheduler_cls
in
self
.
scheduler_classes
:
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
scheduler_cls
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -2341,9 +2264,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2341,9 +2264,7 @@ class PeftLoraLoaderMixinTests:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
if
"text_encoder"
in
self
.
pipeline_class
.
_lora_loadable_modules
:
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
)
pipe
.
text_encoder
.
add_adapter
(
text_lora_config
)
self
.
assertTrue
(
self
.
assertTrue
(
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
check_if_lora_correctly_set
(
pipe
.
text_encoder
),
"Lora not correctly set in text encoder"
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
.
add_adapter
(
denoiser_lora_config
)
denoiser
.
add_adapter
(
denoiser_lora_config
)
...
@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
onload_device
=
torch_device
onload_device
=
torch_device
offload_device
=
torch
.
device
(
"cpu"
)
offload_device
=
torch
.
device
(
"cpu"
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
text_lora_config
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests:
)
)
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdirname
,
"pytorch_lora_weights.safetensors"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdirname
,
"pytorch_lora_weights.safetensors"
)))
components
,
_
,
_
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
_
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
denoiser
=
pipe
.
transformer
if
self
.
unet_kwargs
is
None
else
pipe
.
unet
...
@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests:
@
require_torch_accelerator
@
require_torch_accelerator
def
test_lora_loading_model_cpu_offload
(
self
):
def
test_lora_loading_model_cpu_offload
(
self
):
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
_
,
_
,
inputs
=
self
.
get_dummy_inputs
(
with_generator
=
False
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
torch_device
)
pipe
=
pipe
.
to
(
torch_device
)
...
@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests:
...
@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests:
save_directory
=
tmpdirname
,
safe_serialization
=
True
,
**
lora_state_dicts
save_directory
=
tmpdirname
,
safe_serialization
=
True
,
**
lora_state_dicts
)
)
# reinitialize the pipeline to mimic the inference workflow.
# reinitialize the pipeline to mimic the inference workflow.
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
(
self
.
scheduler_classes
[
0
]
)
components
,
_
,
denoiser_lora_config
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
enable_model_cpu_offload
(
device
=
torch_device
)
pipe
.
enable_model_cpu_offload
(
device
=
torch_device
)
pipe
.
load_lora_weights
(
tmpdirname
)
pipe
.
load_lora_weights
(
tmpdirname
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment