Unverified Commit 09e777a3 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] Single scheduler in lora tests (#12315)

* single scheduler please.

* up

* up

* up
parent a72bc0c4
...@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel ...@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import ( from diffusers import (
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler, CogVideoXDPMScheduler,
CogVideoXPipeline, CogVideoXPipeline,
CogVideoXTransformer3DModel, CogVideoXTransformer3DModel,
...@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"} scheduler_kwargs = {"timestep_spacing": "trailing"}
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
transformer_kwargs = { transformer_kwargs = {
"num_attention_heads": 4, "num_attention_heads": 4,
......
...@@ -50,7 +50,6 @@ class TokenizerWrapper: ...@@ -50,7 +50,6 @@ class TokenizerWrapper:
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogView4Pipeline pipeline_class = CogView4Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
...@@ -124,8 +123,7 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -124,8 +123,7 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
""" """
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
""" """
for scheduler_cls in self.scheduler_classes: components, _, _ = self.get_dummy_components()
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
......
...@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa ...@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend @require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"patch_size": 1, "patch_size": 1,
"in_channels": 4, "in_channels": 4,
...@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline pipeline_class = FluxControlPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"patch_size": 1, "patch_size": 1,
"in_channels": 8, "in_channels": 8,
......
...@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa ...@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
...@@ -141,8 +140,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -141,8 +140,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
strict=False, strict=False,
) )
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -150,9 +148,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -150,9 +148,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
......
...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0) scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {"shift": 7.0}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"patch_size": 1, "patch_size": 1,
"in_channels": 4, "in_channels": 4,
......
...@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"sample_size": 32, "sample_size": 32,
"patch_size": 1, "patch_size": 1,
......
...@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
...@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_peft_version_greater("0.13.2") @require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self): def test_lora_exclude_modules_wanvace(self):
scheduler_cls = self.scheduler_classes[0]
exclude_module_name = "vace_blocks.0.proj_out" exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device) pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
......
...@@ -26,8 +26,6 @@ from parameterized import parameterized ...@@ -26,8 +26,6 @@ from parameterized import parameterized
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMScheduler,
LCMScheduler,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import logging from diffusers.utils import logging
...@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests: ...@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = None scheduler_cls = None
scheduler_kwargs = None scheduler_kwargs = None
scheduler_classes = [DDIMScheduler, LCMScheduler]
has_two_text_encoders = False has_two_text_encoders = False
has_three_text_encoders = False has_three_text_encoders = False
...@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests: ...@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): def get_dummy_components(self, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs: if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders: if self.has_two_text_encoders and self.has_three_text_encoders:
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.") raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls scheduler_cls = self.scheduler_cls
rank = 4 rank = 4
lora_alpha = rank if lora_alpha is None else lora_alpha lora_alpha = rank if lora_alpha is None else lora_alpha
...@@ -319,8 +316,7 @@ class PeftLoraLoaderMixinTests: ...@@ -319,8 +316,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple inference and makes sure it works as expected Tests a simple inference and makes sure it works as expected
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components()
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -334,8 +330,7 @@ class PeftLoraLoaderMixinTests: ...@@ -334,8 +330,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components()
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -354,17 +349,14 @@ class PeftLoraLoaderMixinTests: ...@@ -354,17 +349,14 @@ class PeftLoraLoaderMixinTests:
@require_peft_version_greater("0.13.1") @require_peft_version_greater("0.13.1")
def test_low_cpu_mem_usage_with_injection(self): def test_low_cpu_mem_usage_with_injection(self):
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
)
self.assertTrue( self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder.parameters()}, "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
"The LoRA params should be on 'meta' device.", "The LoRA params should be on 'meta' device.",
...@@ -416,9 +408,7 @@ class PeftLoraLoaderMixinTests: ...@@ -416,9 +408,7 @@ class PeftLoraLoaderMixinTests:
@require_transformers_version_greater("4.45.2") @require_transformers_version_greater("4.45.2")
def test_low_cpu_mem_usage_with_loading(self): def test_low_cpu_mem_usage_with_loading(self):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage.""" """Tests if we can load LoRA state dict with low_cpu_mem_usage."""
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -460,9 +450,7 @@ class PeftLoraLoaderMixinTests: ...@@ -460,9 +450,7 @@ class PeftLoraLoaderMixinTests:
images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
),
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
) )
...@@ -472,9 +460,7 @@ class PeftLoraLoaderMixinTests: ...@@ -472,9 +460,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, _ = self.get_dummy_components()
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -511,8 +497,7 @@ class PeftLoraLoaderMixinTests: ...@@ -511,8 +497,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components()
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -543,8 +528,7 @@ class PeftLoraLoaderMixinTests: ...@@ -543,8 +528,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder, then unloads the lora weights Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components()
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -557,9 +541,7 @@ class PeftLoraLoaderMixinTests: ...@@ -557,9 +541,7 @@ class PeftLoraLoaderMixinTests:
pipe.unload_lora_weights() pipe.unload_lora_weights()
# unloading should remove the LoRA layers # unloading should remove the LoRA layers
self.assertFalse( self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -578,8 +560,7 @@ class PeftLoraLoaderMixinTests: ...@@ -578,8 +560,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple usecase where users could use saving utilities for LoRA. Tests a simple usecase where users could use saving utilities for LoRA.
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components()
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -620,8 +601,7 @@ class PeftLoraLoaderMixinTests: ...@@ -620,8 +601,7 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed with different ranks and some adapters removed
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in self.scheduler_classes: components, _, _ = self.get_dummy_components()
components, _, _ = self.get_dummy_components(scheduler_cls)
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig( text_lora_config = LoraConfig(
r=4, r=4,
...@@ -680,8 +660,7 @@ class PeftLoraLoaderMixinTests: ...@@ -680,8 +660,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, _ = self.get_dummy_components()
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -723,8 +702,7 @@ class PeftLoraLoaderMixinTests: ...@@ -723,8 +702,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -763,9 +741,7 @@ class PeftLoraLoaderMixinTests: ...@@ -763,9 +741,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -808,8 +784,7 @@ class PeftLoraLoaderMixinTests: ...@@ -808,8 +784,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and makes sure it works as expected - with unet
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -824,9 +799,7 @@ class PeftLoraLoaderMixinTests: ...@@ -824,9 +799,7 @@ class PeftLoraLoaderMixinTests:
# Fusing should still keep the LoRA layers # Fusing should still keep the LoRA layers
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser") self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
...@@ -846,8 +819,7 @@ class PeftLoraLoaderMixinTests: ...@@ -846,8 +819,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -860,9 +832,7 @@ class PeftLoraLoaderMixinTests: ...@@ -860,9 +832,7 @@ class PeftLoraLoaderMixinTests:
pipe.unload_lora_weights() pipe.unload_lora_weights()
# unloading should remove the LoRA layers # unloading should remove the LoRA layers
self.assertFalse( self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser") self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
...@@ -885,8 +855,7 @@ class PeftLoraLoaderMixinTests: ...@@ -885,8 +855,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -925,8 +894,7 @@ class PeftLoraLoaderMixinTests: ...@@ -925,8 +894,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them multiple adapters and set them
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -937,9 +905,7 @@ class PeftLoraLoaderMixinTests: ...@@ -937,9 +905,7 @@ class PeftLoraLoaderMixinTests:
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
...@@ -1002,8 +968,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1002,8 +968,7 @@ class PeftLoraLoaderMixinTests:
def test_wrong_adapter_name_raises_error(self): def test_wrong_adapter_name_raises_error(self):
adapter_name = "adapter-1" adapter_name = "adapter-1"
scheduler_cls = self.scheduler_classes[0] components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests:
def test_multiple_wrong_adapter_name_raises_error(self): def test_multiple_wrong_adapter_name_raises_error(self):
adapter_name = "adapter-1" adapter_name = "adapter-1"
scheduler_cls = self.scheduler_classes[0] components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1054,8 +1018,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1054,8 +1018,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and set different weights for different blocks (i.e. block lora) one adapter and set different weights for different blocks (i.e. block lora)
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1111,8 +1074,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1111,8 +1074,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set different weights for different blocks (i.e. block lora) multiple adapters and set different weights for different blocks (i.e. block lora)
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1123,9 +1085,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1123,9 +1085,7 @@ class PeftLoraLoaderMixinTests:
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
...@@ -1274,8 +1234,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1274,8 +1234,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set/delete them multiple adapters and set/delete them
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1286,9 +1245,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1286,9 +1245,7 @@ class PeftLoraLoaderMixinTests:
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
...@@ -1368,8 +1325,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1368,8 +1325,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them multiple adapters and set them
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1380,9 +1336,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1380,9 +1336,7 @@ class PeftLoraLoaderMixinTests:
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
...@@ -1446,8 +1400,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1446,8 +1400,7 @@ class PeftLoraLoaderMixinTests:
strict=False, strict=False,
) )
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1455,9 +1408,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1455,9 +1408,7 @@ class PeftLoraLoaderMixinTests:
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
...@@ -1466,9 +1417,9 @@ class PeftLoraLoaderMixinTests: ...@@ -1466,9 +1417,9 @@ class PeftLoraLoaderMixinTests:
# corrupt one LoRA weight with `inf` values # corrupt one LoRA weight with `inf` values
with torch.no_grad(): with torch.no_grad():
if self.unet_kwargs: if self.unet_kwargs:
pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[ pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
"adapter-1" "inf"
].weight += float("inf") )
else: else:
named_modules = [name for name, _ in pipe.transformer.named_modules()] named_modules = [name for name, _ in pipe.transformer.named_modules()]
possible_tower_names = [ possible_tower_names = [
...@@ -1481,9 +1432,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1481,9 +1432,7 @@ class PeftLoraLoaderMixinTests:
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
] ]
if len(filtered_tower_names) == 0: if len(filtered_tower_names) == 0:
reason = ( reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
)
raise ValueError(reason) raise ValueError(reason)
for tower_name in filtered_tower_names: for tower_name in filtered_tower_names:
transformer_tower = getattr(pipe.transformer, tower_name) transformer_tower = getattr(pipe.transformer, tower_name)
...@@ -1508,8 +1457,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1508,8 +1457,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where we attach multiple adapters and check if the results Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results are the expected results
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1537,8 +1485,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1537,8 +1485,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where we attach multiple adapters and check if the results Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results are the expected results
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1612,8 +1559,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1612,8 +1559,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and multi-adapter case and makes sure it works as expected - with unet and multi-adapter case
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1624,9 +1570,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1624,9 +1570,7 @@ class PeftLoraLoaderMixinTests:
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
...@@ -1675,9 +1619,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1675,9 +1619,7 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
) )
pipe.fuse_lora( pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"])
components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
)
self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
# Fusing should still keep the LoRA layers # Fusing should still keep the LoRA layers
...@@ -1693,8 +1635,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1693,8 +1635,7 @@ class PeftLoraLoaderMixinTests:
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
for lora_scale in [1.0, 0.8]: for lora_scale in [1.0, 0.8]:
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1746,10 +1687,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1746,10 +1687,7 @@ class PeftLoraLoaderMixinTests:
@require_peft_version_greater(peft_version="0.9.0") @require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self): def test_simple_inference_with_dora(self):
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, use_dora=True
)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1768,9 +1706,8 @@ class PeftLoraLoaderMixinTests: ...@@ -1768,9 +1706,8 @@ class PeftLoraLoaderMixinTests:
) )
def test_missing_keys_warning(self): def test_missing_keys_warning(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`. # Skip text encoder check for now as that is handled with `transformers`.
components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests: ...@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")) self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
def test_unexpected_keys_warning(self): def test_unexpected_keys_warning(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`. # Skip text encoder check for now as that is handled with `transformers`.
components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1842,8 +1778,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1842,8 +1778,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1857,7 +1792,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1857,7 +1792,7 @@ class PeftLoraLoaderMixinTests:
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True) pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
# Just makes sure it works.. # Just makes sure it works.
_ = pipe(**inputs, generator=torch.manual_seed(0))[0] _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
...@@ -1866,8 +1801,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1866,8 +1801,7 @@ class PeftLoraLoaderMixinTests:
if isinstance(module, torch.nn.Conv2d): if isinstance(module, torch.nn.Conv2d):
module.padding_mode = mode module.padding_mode = mode
for scheduler_cls in self.scheduler_classes: components, _, _ = self.get_dummy_components()
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1879,9 +1813,8 @@ class PeftLoraLoaderMixinTests: ...@@ -1879,9 +1813,8 @@ class PeftLoraLoaderMixinTests:
_ = pipe(**inputs)[0] _ = pipe(**inputs)[0]
def test_logs_info_when_no_lora_keys_found(self): def test_logs_info_when_no_lora_keys_found(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`. # Skip text encoder check for now as that is handled with `transformers`.
components, _, _ = self.get_dummy_components(scheduler_cls) components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1925,9 +1858,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1925,9 +1858,7 @@ class PeftLoraLoaderMixinTests:
def test_set_adapters_match_attention_kwargs(self): def test_set_adapters_match_attention_kwargs(self):
"""Test to check if outputs after `set_adapters()` and attention kwargs match.""" """Test to check if outputs after `set_adapters()` and attention kwargs match."""
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1991,7 +1922,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1991,7 +1922,7 @@ class PeftLoraLoaderMixinTests:
def test_lora_B_bias(self): def test_lora_B_bias(self):
# Currently, this test is only relevant for Flux Control LoRA as we are not # Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained. # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests:
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
def test_correct_lora_configs_with_different_ranks(self): def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests:
self.assertEqual(submodule.bias.dtype, dtype_to_check) self.assertEqual(submodule.bias.dtype, dtype_to_check)
def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype) pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None) self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
# 1. Test forward with add_adapter # 1. Test forward with add_adapter
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype) pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests:
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
components, _, _ = self.get_dummy_components(self.scheduler_classes[0]) components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype) pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests:
@parameterized.expand([4, 8, 16]) @parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0] components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe, _ = self.add_adapters_to_pipeline( pipe, _ = self.add_adapters_to_pipeline(
...@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests:
@parameterized.expand([4, 8, 16]) @parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0] components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
pipe = self.pipeline_class(**components).to(torch_device) pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
...@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests:
def test_lora_unload_add_adapter(self): def test_lora_unload_add_adapter(self):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works.""" """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
scheduler_cls = self.scheduler_classes[0] components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components).to(torch_device) pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
...@@ -2330,8 +2254,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2330,8 +2254,7 @@ class PeftLoraLoaderMixinTests:
def test_inference_load_delete_load_adapters(self): def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -2341,9 +2264,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2341,9 +2264,7 @@ class PeftLoraLoaderMixinTests:
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config) denoiser.add_adapter(denoiser_lora_config)
...@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
onload_device = torch_device onload_device = torch_device
offload_device = torch.device("cpu") offload_device = torch.device("cpu")
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests:
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
components, _, _ = self.get_dummy_components(self.scheduler_classes[0]) components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
...@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests:
@require_torch_accelerator @require_torch_accelerator
def test_lora_loading_model_cpu_offload(self): def test_lora_loading_model_cpu_offload(self):
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) components, _, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests:
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
) )
# reinitialize the pipeline to mimic the inference workflow. # reinitialize the pipeline to mimic the inference workflow.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe.enable_model_cpu_offload(device=torch_device) pipe.enable_model_cpu_offload(device=torch_device)
pipe.load_lora_weights(tmpdirname) pipe.load_lora_weights(tmpdirname)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment