Unverified Commit e5d0a328 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[refactor] LoRA tests (#9481)

* refactor scheduler class usage

* reorder to make tests more readable

* remove pipeline specific checks and skip tests directly

* rewrite denoiser conditions cleaner

* bump tolerance for cog test
parent 14a1b86f
...@@ -48,6 +48,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -48,6 +48,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"} scheduler_kwargs = {"timestep_spacing": "trailing"}
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
transformer_kwargs = { transformer_kwargs = {
"num_attention_heads": 4, "num_attention_heads": 4,
...@@ -126,8 +127,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -126,8 +127,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@skip_mps @skip_mps
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self):
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler] for scheduler_cls in self.scheduler_classes:
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -156,10 +156,22 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -156,10 +156,22 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(np.isnan(out).all()) self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3) super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=5e-3) super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in CogVideoX.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.") @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
......
...@@ -47,7 +47,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -47,7 +47,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {} scheduler_kwargs = {}
uses_flow_matching = True scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"patch_size": 1, "patch_size": 1,
"in_channels": 4, "in_channels": 4,
...@@ -154,6 +154,14 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -154,6 +154,14 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
) )
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Flux.")
def test_modify_padding_mode(self):
pass
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -34,7 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -34,7 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
uses_flow_matching = True scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"sample_size": 32, "sample_size": 32,
"patch_size": 1, "patch_size": 1,
...@@ -92,3 +92,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -92,3 +92,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
lora_filename = "lora_peft_format.safetensors" lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in SD3.")
def test_modify_padding_mode(self):
pass
...@@ -24,7 +24,6 @@ import torch ...@@ -24,7 +24,6 @@ import torch
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMScheduler, DDIMScheduler,
FlowMatchEulerDiscreteScheduler,
LCMScheduler, LCMScheduler,
UNet2DConditionModel, UNet2DConditionModel,
) )
...@@ -69,9 +68,10 @@ def check_if_lora_correctly_set(model) -> bool: ...@@ -69,9 +68,10 @@ def check_if_lora_correctly_set(model) -> bool:
@require_peft_backend @require_peft_backend
class PeftLoraLoaderMixinTests: class PeftLoraLoaderMixinTests:
pipeline_class = None pipeline_class = None
scheduler_cls = None scheduler_cls = None
scheduler_kwargs = None scheduler_kwargs = None
uses_flow_matching = False scheduler_classes = [DDIMScheduler, LCMScheduler]
has_two_text_encoders = False has_two_text_encoders = False
has_three_text_encoders = False has_three_text_encoders = False
...@@ -205,13 +205,7 @@ class PeftLoraLoaderMixinTests: ...@@ -205,13 +205,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple inference and makes sure it works as expected Tests a simple inference and makes sure it works as expected
""" """
# TODO(aryan): Some of the assumptions made here in many different tests are incorrect for CogVideoX. for scheduler_cls in self.scheduler_classes:
# For example, we need to test with CogVideoXDDIMScheduler and CogVideoDPMScheduler instead of DDIMScheduler
# and LCMScheduler, which are not supported by it.
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -226,10 +220,7 @@ class PeftLoraLoaderMixinTests: ...@@ -226,10 +220,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -260,17 +251,16 @@ class PeftLoraLoaderMixinTests: ...@@ -260,17 +251,16 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder + scale argument Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
# TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]: for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
if possible_attention_kwargs in call_signature_keys: if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs attention_kwargs_name = possible_attention_kwargs
break break
assert attention_kwargs_name is not None assert attention_kwargs_name is not None
for scheduler_cls in scheduler_classes: for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -317,10 +307,7 @@ class PeftLoraLoaderMixinTests: ...@@ -317,10 +307,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -360,10 +347,7 @@ class PeftLoraLoaderMixinTests: ...@@ -360,10 +347,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder, then unloads the lora weights Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -410,10 +394,7 @@ class PeftLoraLoaderMixinTests: ...@@ -410,10 +394,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple usecase where users could use saving utilities for LoRA. Tests a simple usecase where users could use saving utilities for LoRA.
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -490,10 +471,7 @@ class PeftLoraLoaderMixinTests: ...@@ -490,10 +471,7 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed with different ranks and some adapters removed
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls) components, _, _ = self.get_dummy_components(scheduler_cls)
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig( text_lora_config = LoraConfig(
...@@ -555,10 +533,7 @@ class PeftLoraLoaderMixinTests: ...@@ -555,10 +533,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -609,10 +584,7 @@ class PeftLoraLoaderMixinTests: ...@@ -609,10 +584,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -628,13 +600,9 @@ class PeftLoraLoaderMixinTests: ...@@ -628,13 +600,9 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config) denoiser.add_adapter(denoiser_lora_config)
else: self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
pipe.transformer.add_adapter(denoiser_lora_config)
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -652,10 +620,7 @@ class PeftLoraLoaderMixinTests: ...@@ -652,10 +620,7 @@ class PeftLoraLoaderMixinTests:
else None else None
) )
if self.unet_kwargs is not None: denoiser_state_dict = get_peft_model_state_dict(denoiser)
denoiser_state_dict = get_peft_model_state_dict(pipe.unet)
else:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
saving_kwargs = { saving_kwargs = {
"save_directory": tmpdirname, "save_directory": tmpdirname,
...@@ -689,8 +654,7 @@ class PeftLoraLoaderMixinTests: ...@@ -689,8 +654,7 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -708,9 +672,6 @@ class PeftLoraLoaderMixinTests: ...@@ -708,9 +672,6 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder + Unet + scale argument Tests a simple inference with lora attached on the text encoder + Unet + scale argument
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]: for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
if possible_attention_kwargs in call_signature_keys: if possible_attention_kwargs in call_signature_keys:
...@@ -718,7 +679,7 @@ class PeftLoraLoaderMixinTests: ...@@ -718,7 +679,7 @@ class PeftLoraLoaderMixinTests:
break break
assert attention_kwargs_name is not None assert attention_kwargs_name is not None
for scheduler_cls in scheduler_classes: for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -734,13 +695,9 @@ class PeftLoraLoaderMixinTests: ...@@ -734,13 +695,9 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config) denoiser.add_adapter(denoiser_lora_config)
else: self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
pipe.transformer.add_adapter(denoiser_lora_config)
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -781,10 +738,7 @@ class PeftLoraLoaderMixinTests: ...@@ -781,10 +738,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and makes sure it works as expected - with unet
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -800,13 +754,9 @@ class PeftLoraLoaderMixinTests: ...@@ -800,13 +754,9 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config) denoiser.add_adapter(denoiser_lora_config)
else: self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
pipe.transformer.add_adapter(denoiser_lora_config)
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -823,8 +773,7 @@ class PeftLoraLoaderMixinTests: ...@@ -823,8 +773,7 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -842,10 +791,7 @@ class PeftLoraLoaderMixinTests: ...@@ -842,10 +791,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -861,12 +807,9 @@ class PeftLoraLoaderMixinTests: ...@@ -861,12 +807,9 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config) denoiser.add_adapter(denoiser_lora_config)
else: self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
pipe.transformer.add_adapter(denoiser_lora_config)
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -880,10 +823,7 @@ class PeftLoraLoaderMixinTests: ...@@ -880,10 +823,7 @@ class PeftLoraLoaderMixinTests:
self.assertFalse( self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
) )
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
self.assertFalse(
check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly unloaded in denoiser"
)
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -905,10 +845,7 @@ class PeftLoraLoaderMixinTests: ...@@ -905,10 +845,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -921,13 +858,9 @@ class PeftLoraLoaderMixinTests: ...@@ -921,13 +858,9 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config) denoiser.add_adapter(denoiser_lora_config)
else: self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
pipe.transformer.add_adapter(denoiser_lora_config)
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -946,8 +879,7 @@ class PeftLoraLoaderMixinTests: ...@@ -946,8 +879,7 @@ class PeftLoraLoaderMixinTests:
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -966,10 +898,7 @@ class PeftLoraLoaderMixinTests: ...@@ -966,10 +898,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them multiple adapters and set them
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -985,17 +914,10 @@ class PeftLoraLoaderMixinTests: ...@@ -985,17 +914,10 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
else: denoiser.add_adapter(denoiser_lora_config, "adapter-2")
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -1041,15 +963,9 @@ class PeftLoraLoaderMixinTests: ...@@ -1041,15 +963,9 @@ class PeftLoraLoaderMixinTests:
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
""" """
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and set differnt weights for different blocks (i.e. block lora) one adapter and set different weights for different blocks (i.e. block lora)
""" """
if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "CogVideoXPipeline"]: for scheduler_cls in self.scheduler_classes:
return
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1059,14 +975,11 @@ class PeftLoraLoaderMixinTests: ...@@ -1059,14 +975,11 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -1109,13 +1022,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1109,13 +1022,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set differnt weights for different blocks (i.e. block lora) multiple adapters and set differnt weights for different blocks (i.e. block lora)
""" """
if self.pipeline_class.__name__ == "StableDiffusion3Pipeline": for scheduler_cls in self.scheduler_classes:
return
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1131,17 +1038,10 @@ class PeftLoraLoaderMixinTests: ...@@ -1131,17 +1038,10 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
else: denoiser.add_adapter(denoiser_lora_config, "adapter-2")
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -1193,8 +1093,6 @@ class PeftLoraLoaderMixinTests: ...@@ -1193,8 +1093,6 @@ class PeftLoraLoaderMixinTests:
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
"""Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]:
return
def updown_options(blocks_with_tf, layers_per_block, value): def updown_options(blocks_with_tf, layers_per_block, value):
""" """
...@@ -1266,10 +1164,9 @@ class PeftLoraLoaderMixinTests: ...@@ -1266,10 +1164,9 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
else: denoiser.add_adapter(denoiser_lora_config, "adapter-1")
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules lora_loadable_components = self.pipeline_class._lora_loadable_modules
...@@ -1288,10 +1185,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1288,10 +1185,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set/delete them multiple adapters and set/delete them
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1307,18 +1201,10 @@ class PeftLoraLoaderMixinTests: ...@@ -1307,18 +1201,10 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
else: denoiser.add_adapter(denoiser_lora_config, "adapter-2")
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules lora_loadable_components = self.pipeline_class._lora_loadable_modules
...@@ -1373,14 +1259,10 @@ class PeftLoraLoaderMixinTests: ...@@ -1373,14 +1259,10 @@ class PeftLoraLoaderMixinTests:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
else: denoiser.add_adapter(denoiser_lora_config, "adapter-2")
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.set_adapters(["adapter-1", "adapter-2"]) pipe.set_adapters(["adapter-1", "adapter-2"])
pipe.delete_adapters(["adapter-1", "adapter-2"]) pipe.delete_adapters(["adapter-1", "adapter-2"])
...@@ -1397,10 +1279,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1397,10 +1279,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them multiple adapters and set them
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1416,17 +1295,10 @@ class PeftLoraLoaderMixinTests: ...@@ -1416,17 +1295,10 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
else: denoiser.add_adapter(denoiser_lora_config, "adapter-2")
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules lora_loadable_components = self.pipeline_class._lora_loadable_modules
...@@ -1471,7 +1343,6 @@ class PeftLoraLoaderMixinTests: ...@@ -1471,7 +1343,6 @@ class PeftLoraLoaderMixinTests:
) )
pipe.disable_lora() pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( self.assertTrue(
...@@ -1481,10 +1352,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1481,10 +1352,7 @@ class PeftLoraLoaderMixinTests:
@skip_mps @skip_mps
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self):
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1497,13 +1365,9 @@ class PeftLoraLoaderMixinTests: ...@@ -1497,13 +1365,9 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
else: self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values # corrupt one LoRA weight with `inf` values
with torch.no_grad(): with torch.no_grad():
...@@ -1520,7 +1384,6 @@ class PeftLoraLoaderMixinTests: ...@@ -1520,7 +1384,6 @@ class PeftLoraLoaderMixinTests:
# without we should not see an error, but every image will be black # without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe("test", num_inference_steps=2, output_type="np")[0] out = pipe("test", num_inference_steps=2, output_type="np")[0]
self.assertTrue(np.isnan(out).all()) self.assertTrue(np.isnan(out).all())
...@@ -1530,10 +1393,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1530,10 +1393,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where we attach multiple adapters and check if the results Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results are the expected results
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1541,19 +1401,15 @@ class PeftLoraLoaderMixinTests: ...@@ -1541,19 +1401,15 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
else: denoiser.add_adapter(denoiser_lora_config, "adapter-1")
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
adapter_names = pipe.get_active_adapters() adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-1"]) self.assertListEqual(adapter_names, ["adapter-1"])
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None: denoiser.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
adapter_names = pipe.get_active_adapters() adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-2"]) self.assertListEqual(adapter_names, ["adapter-2"])
...@@ -1566,10 +1422,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1566,10 +1422,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where we attach multiple adapters and check if the results Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results are the expected results
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1583,12 +1436,9 @@ class PeftLoraLoaderMixinTests: ...@@ -1583,12 +1436,9 @@ class PeftLoraLoaderMixinTests:
if self.unet_kwargs is not None: if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
if self.unet_kwargs is not None:
dicts_to_be_checked.update({"unet": ["adapter-1"]}) dicts_to_be_checked.update({"unet": ["adapter-1"]})
else: else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
dicts_to_be_checked.update({"transformer": ["adapter-1"]}) dicts_to_be_checked.update({"transformer": ["adapter-1"]})
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
...@@ -1601,12 +1451,9 @@ class PeftLoraLoaderMixinTests: ...@@ -1601,12 +1451,9 @@ class PeftLoraLoaderMixinTests:
if self.unet_kwargs is not None: if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
if self.unet_kwargs is not None:
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
else: else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
...@@ -1629,18 +1476,15 @@ class PeftLoraLoaderMixinTests: ...@@ -1629,18 +1476,15 @@ class PeftLoraLoaderMixinTests:
) )
# 4. # 4.
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
dicts_to_be_checked = {} dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
if self.unet_kwargs is not None: if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]}) dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
else: else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]}) dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
...@@ -1653,10 +1497,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1653,10 +1497,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and multi-adapter case and makes sure it works as expected - with unet and multi-adapter case
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1672,22 +1513,16 @@ class PeftLoraLoaderMixinTests: ...@@ -1672,22 +1513,16 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
if self.unet_kwargs is not None: denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
# Attach a second adapter # Attach a second adapter
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None: denoiser.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules lora_loadable_components = self.pipeline_class._lora_loadable_modules
...@@ -1729,10 +1564,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1729,10 +1564,7 @@ class PeftLoraLoaderMixinTests:
@require_peft_version_greater(peft_version="0.9.0") @require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self): def test_simple_inference_with_dora(self):
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components( components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, use_dora=True scheduler_cls, use_dora=True
) )
...@@ -1745,14 +1577,11 @@ class PeftLoraLoaderMixinTests: ...@@ -1745,14 +1577,11 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(output_no_dora_lora.shape == self.output_shape) self.assertTrue(output_no_dora_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules lora_loadable_components = self.pipeline_class._lora_loadable_modules
...@@ -1775,10 +1604,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1775,10 +1604,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -1786,14 +1612,11 @@ class PeftLoraLoaderMixinTests: ...@@ -1786,14 +1612,11 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
...@@ -1811,18 +1634,12 @@ class PeftLoraLoaderMixinTests: ...@@ -1811,18 +1634,12 @@ class PeftLoraLoaderMixinTests:
_ = pipe(**inputs, generator=torch.manual_seed(0))[0] _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]:
return
def set_pad_mode(network, mode="circular"): def set_pad_mode(network, mode="circular"):
for _, module in network.named_modules(): for _, module in network.named_modules():
if isinstance(module, torch.nn.Conv2d): if isinstance(module, torch.nn.Conv2d):
module.padding_mode = mode module.padding_mode = mode
scheduler_classes = ( for scheduler_cls in self.scheduler_classes:
[FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls) components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment