Unverified Commit 09e777a3 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] Single scheduler in lora tests (#12315)

* single scheduler please.

* up

* up

* up
parent a72bc0c4
...@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel ...@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import ( from diffusers import (
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler, CogVideoXDPMScheduler,
CogVideoXPipeline, CogVideoXPipeline,
CogVideoXTransformer3DModel, CogVideoXTransformer3DModel,
...@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"} scheduler_kwargs = {"timestep_spacing": "trailing"}
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
transformer_kwargs = { transformer_kwargs = {
"num_attention_heads": 4, "num_attention_heads": 4,
......
...@@ -50,7 +50,6 @@ class TokenizerWrapper: ...@@ -50,7 +50,6 @@ class TokenizerWrapper:
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogView4Pipeline pipeline_class = CogView4Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
...@@ -124,8 +123,7 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -124,8 +123,7 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
""" """
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
""" """
for scheduler_cls in self.scheduler_classes: components, _, _ = self.get_dummy_components()
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
......
...@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa ...@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend @require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"patch_size": 1, "patch_size": 1,
"in_channels": 4, "in_channels": 4,
...@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline pipeline_class = FluxControlPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"patch_size": 1, "patch_size": 1,
"in_channels": 8, "in_channels": 8,
......
...@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa ...@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
...@@ -141,8 +140,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -141,8 +140,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
strict=False, strict=False,
) )
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -150,9 +148,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -150,9 +148,7 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue( self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
......
...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0) scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {"shift": 7.0}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"patch_size": 1, "patch_size": 1,
"in_channels": 4, "in_channels": 4,
......
...@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"sample_size": 32, "sample_size": 32,
"patch_size": 1, "patch_size": 1,
......
...@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
......
...@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402 ...@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {} scheduler_kwargs = {}
transformer_kwargs = { transformer_kwargs = {
...@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_peft_version_greater("0.13.2") @require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self): def test_lora_exclude_modules_wanvace(self):
scheduler_cls = self.scheduler_classes[0]
exclude_module_name = "vace_blocks.0.proj_out" exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device) pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment