Unverified commit 09e777a3 authored by Sayak Paul, committed by GitHub

[tests] Single scheduler in lora tests (#12315)

* single scheduler please.

* up

* up

* up
parent a72bc0c4
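For reference, this commit applies one pattern across the LoRA test suite: the per-class `scheduler_classes` list (and the `for scheduler_cls in self.scheduler_classes:` loop that wrapped each test body) is dropped in favor of the single `scheduler_cls` attribute every test class already declares. A minimal sketch of the before/after shape, using names from the diff below; `MyLoRATests` is illustrative and not a class touched by this commit:

    # Before: every test body looped over a list of scheduler classes.
    class MyLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        scheduler_cls = FlowMatchEulerDiscreteScheduler
        scheduler_classes = [FlowMatchEulerDiscreteScheduler]

        def test_simple_inference(self):
            for scheduler_cls in self.scheduler_classes:
                components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
                ...

    # After: get_dummy_components() reads self.scheduler_cls directly,
    # so both the loop and the extra class attribute disappear.
    class MyLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        scheduler_cls = FlowMatchEulerDiscreteScheduler
        scheduler_kwargs = {}

        def test_simple_inference(self):
            components, text_lora_config, _ = self.get_dummy_components()
            ...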
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = AuraFlowPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import (
     AutoencoderKLCogVideoX,
-    CogVideoXDDIMScheduler,
     CogVideoXDPMScheduler,
     CogVideoXPipeline,
     CogVideoXTransformer3DModel,
@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = CogVideoXPipeline
     scheduler_cls = CogVideoXDPMScheduler
     scheduler_kwargs = {"timestep_spacing": "trailing"}
-    scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
     transformer_kwargs = {
         "num_attention_heads": 4,
@@ -50,7 +50,6 @@ class TokenizerWrapper:
 class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = CogView4Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -124,8 +123,7 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
 @require_peft_backend
 class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 4,
@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxControlPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 8,
@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = HunyuanVideoPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = LTXPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
 class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = Lumina2Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -141,8 +140,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         strict=False,
     )
     def test_lora_fuse_nan(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -150,9 +148,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = MochiPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = QwenImagePipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 @require_peft_backend
 class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = SanaPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
-    scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
+    scheduler_kwargs = {"shift": 7.0}
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 4,
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "sample_size": 32,
         "patch_size": 1,
@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = WanPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = WanVACEPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     @require_peft_version_greater("0.13.2")
     def test_lora_exclude_modules_wanvace(self):
-        scheduler_cls = self.scheduler_classes[0]
         exclude_module_name = "vace_blocks.0.proj_out"
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
@@ -26,8 +26,6 @@ from parameterized import parameterized
 from diffusers import (
     AutoencoderKL,
-    DDIMScheduler,
-    LCMScheduler,
     UNet2DConditionModel,
 )

 from diffusers.utils import logging
@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
     scheduler_cls = None
     scheduler_kwargs = None
-    scheduler_classes = [DDIMScheduler, LCMScheduler]

     has_two_text_encoders = False
     has_three_text_encoders = False
@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

-    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
+    def get_dummy_components(self, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
         if self.has_two_text_encoders and self.has_three_text_encoders:
             raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

-        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
+        scheduler_cls = self.scheduler_cls
         rank = 4
         lora_alpha = rank if lora_alpha is None else lora_alpha
@@ -319,8 +316,7 @@ class PeftLoraLoaderMixinTests:
         """
         Tests a simple inference and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -334,8 +330,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached on the text encoder
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -354,17 +349,14 @@ class PeftLoraLoaderMixinTests:
     @require_peft_version_greater("0.13.1")
     def test_low_cpu_mem_usage_with_injection(self):
         """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
             self.assertTrue(
                 "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
                 "The LoRA params should be on 'meta' device.",
@@ -416,9 +408,7 @@ class PeftLoraLoaderMixinTests:
     @require_transformers_version_greater("4.45.2")
     def test_low_cpu_mem_usage_with_loading(self):
         """Tests if we can load LoRA state dict with low_cpu_mem_usage."""
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -460,9 +450,7 @@ class PeftLoraLoaderMixinTests:
         images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
-                np.allclose(
-                    images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
-                ),
+            np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
         )
@@ -472,9 +460,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -511,8 +497,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -543,8 +528,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder, then unloads the lora weights
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -557,9 +541,7 @@ class PeftLoraLoaderMixinTests:
         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
-            self.assertFalse(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
-            )
+        self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
@@ -578,8 +560,7 @@ class PeftLoraLoaderMixinTests:
         """
         Tests a simple usecase where users could use saving utilities for LoRA.
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -620,8 +601,7 @@ class PeftLoraLoaderMixinTests:
         with different ranks and some adapters removed
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
         # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
         text_lora_config = LoraConfig(
             r=4,
@@ -680,8 +660,7 @@ class PeftLoraLoaderMixinTests:
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -723,8 +702,7 @@ class PeftLoraLoaderMixinTests:
         """
         Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -763,9 +741,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -808,8 +784,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -824,9 +799,7 @@ class PeftLoraLoaderMixinTests:
         # Fusing should still keep the LoRA layers
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
@@ -846,8 +819,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -860,9 +832,7 @@ class PeftLoraLoaderMixinTests:
         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
-            self.assertFalse(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
-            )
+        self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")

         self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

         if self.has_two_text_encoders or self.has_three_text_encoders:
@@ -885,8 +855,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -925,8 +894,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set them
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -937,9 +905,7 @@ class PeftLoraLoaderMixinTests:
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
@@ -1002,8 +968,7 @@ class PeftLoraLoaderMixinTests:
     def test_wrong_adapter_name_raises_error(self):
         adapter_name = "adapter-1"
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests:
     def test_multiple_wrong_adapter_name_raises_error(self):
         adapter_name = "adapter-1"
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1054,8 +1018,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, attaches
         one adapter and set different weights for different blocks (i.e. block lora)
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1111,8 +1074,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set different weights for different blocks (i.e. block lora)
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1123,9 +1085,7 @@ class PeftLoraLoaderMixinTests:
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
@@ -1274,8 +1234,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set/delete them
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1286,9 +1245,7 @@ class PeftLoraLoaderMixinTests:
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
@@ -1368,8 +1325,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set them
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1380,9 +1336,7 @@ class PeftLoraLoaderMixinTests:
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
@@ -1446,8 +1400,7 @@ class PeftLoraLoaderMixinTests:
         strict=False,
     )
     def test_lora_fuse_nan(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1455,9 +1408,7 @@ class PeftLoraLoaderMixinTests:
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
@@ -1466,9 +1417,9 @@ class PeftLoraLoaderMixinTests:
         # corrupt one LoRA weight with `inf` values
         with torch.no_grad():
             if self.unet_kwargs:
-                    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
-                        "adapter-1"
-                    ].weight += float("inf")
+                pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
+                    "inf"
+                )
             else:
                 named_modules = [name for name, _ in pipe.transformer.named_modules()]
                 possible_tower_names = [
@@ -1481,9 +1432,7 @@ class PeftLoraLoaderMixinTests:
                     tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
                 ]
                 if len(filtered_tower_names) == 0:
-                        reason = (
-                            f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
-                        )
+                    reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
                     raise ValueError(reason)
                 for tower_name in filtered_tower_names:
                     transformer_tower = getattr(pipe.transformer, tower_name)
@@ -1508,8 +1457,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where we attach multiple adapters and check if the results
         are the expected results
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1537,8 +1485,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where we attach multiple adapters and check if the results
         are the expected results
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1612,8 +1559,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet and multi-adapter case
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1624,9 +1570,7 @@ class PeftLoraLoaderMixinTests:
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -1675,9 +1619,7 @@ class PeftLoraLoaderMixinTests:
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                 )

-            pipe.fuse_lora(
-                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
-            )
+        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"])
         self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

         # Fusing should still keep the LoRA layers
@@ -1693,8 +1635,7 @@ class PeftLoraLoaderMixinTests:
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for lora_scale in [1.0, 0.8]:
-            for scheduler_cls in self.scheduler_classes:
-                components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1746,10 +1687,7 @@ class PeftLoraLoaderMixinTests:
     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
-                scheduler_cls, use_dora=True
-            )
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1768,9 +1706,8 @@ class PeftLoraLoaderMixinTests:
     )
     def test_missing_keys_warning(self):
-        scheduler_cls = self.scheduler_classes[0]
         # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, _, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests:
             self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))

     def test_unexpected_keys_warning(self):
-        scheduler_cls = self.scheduler_classes[0]
         # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, _, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1842,8 +1778,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1857,7 +1792,7 @@ class PeftLoraLoaderMixinTests:
         if self.has_two_text_encoders or self.has_three_text_encoders:
             pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

-            # Just makes sure it works..
+        # Just makes sure it works.
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

     def test_modify_padding_mode(self):
@@ -1866,8 +1801,7 @@ class PeftLoraLoaderMixinTests:
             if isinstance(module, torch.nn.Conv2d):
                 module.padding_mode = mode

-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1879,9 +1813,8 @@ class PeftLoraLoaderMixinTests:
         _ = pipe(**inputs)[0]

     def test_logs_info_when_no_lora_keys_found(self):
-        scheduler_cls = self.scheduler_classes[0]
         # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1925,9 +1858,7 @@ class PeftLoraLoaderMixinTests:
     def test_set_adapters_match_attention_kwargs(self):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -1991,7 +1922,7 @@ class PeftLoraLoaderMixinTests:
     def test_lora_B_bias(self):
         # Currently, this test is only relevant for Flux Control LoRA as we are not
         # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, _, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests:
         self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))

     def test_correct_lora_configs_with_different_ranks(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, _, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests:
                     self.assertEqual(submodule.bias.dtype, dtype_to_check)

         def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device, dtype=compute_dtype)
             pipe.set_progress_bar_config(disable=None)
@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
                 self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)

         # 1. Test forward with add_adapter
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, _, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device, dtype=compute_dtype)
         pipe.set_progress_bar_config(disable=None)
@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests:
             )
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

-            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            components, _, _ = self.get_dummy_components()
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device, dtype=compute_dtype)
             pipe.set_progress_bar_config(disable=None)
@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests:
     @parameterized.expand([4, 8, 16])
     def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
-            scheduler_cls, lora_alpha=lora_alpha
-        )
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
         pipe = self.pipeline_class(**components)

         pipe, _ = self.add_adapters_to_pipeline(
@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests:
     @parameterized.expand([4, 8, 16])
     def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
-            scheduler_cls, lora_alpha=lora_alpha
-        )
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests:
     def test_lora_unload_add_adapter(self):
         """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
@@ -2330,8 +2254,7 @@ class PeftLoraLoaderMixinTests:
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -2341,9 +2264,7 @@ class PeftLoraLoaderMixinTests:
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config)
@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
         onload_device = torch_device
         offload_device = torch.device("cpu")

-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests:
             )
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

-            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            components, _, _ = self.get_dummy_components()
             pipe = self.pipeline_class(**components)
             pipe.set_progress_bar_config(disable=None)
             denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests:
     @require_torch_accelerator
     def test_lora_loading_model_cpu_offload(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        components, _, denoiser_lora_config = self.get_dummy_components()
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests:
                 save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
             )
             # reinitialize the pipeline to mimic the inference workflow.
-            components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+            components, _, denoiser_lora_config = self.get_dummy_components()
             pipe = self.pipeline_class(**components)
             pipe.enable_model_cpu_offload(device=torch_device)
             pipe.load_lora_weights(tmpdirname)
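One detail worth calling out from the Sana hunk: with a single `scheduler_cls`, the attribute must hold a class rather than an instance, so constructor arguments such as `shift=7.0` move into `scheduler_kwargs`. A hedged sketch of how `get_dummy_components()` presumably consumes the two attributes; the actual instantiation site falls outside the hunks shown above:

    # Sketch only; assumes get_dummy_components() builds the scheduler as
    # scheduler_cls(**scheduler_kwargs), which the Sana change implies.
    scheduler_cls = self.scheduler_cls
    scheduler = scheduler_cls(**self.scheduler_kwargs)  # e.g. FlowMatchEulerDiscreteScheduler(shift=7.0)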