Unverified commit 09e777a3, authored by Sayak Paul and committed by GitHub

[tests] Single scheduler in lora tests (#12315)

* single scheduler please.

* up

* up

* up
parent a72bc0c4
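Summary of the change: each LoRA test suite previously declared a `scheduler_classes` list and looped over it inside individual tests; after this commit a suite keeps a single `scheduler_cls` class attribute (the scheduler class, not an instance) plus `scheduler_kwargs`, and `get_dummy_components()` no longer takes a scheduler argument. The sketch below illustrates the pattern only; it assumes the mixin (imported from `.utils` in these files, not shown in this diff) builds the scheduler as `scheduler_cls(**scheduler_kwargs)`, and all helper names here are illustrative.

    # Minimal sketch of the single-scheduler pattern (assumption: PeftLoraLoaderMixinTests
    # consumes the `scheduler_cls` and `scheduler_kwargs` class attributes roughly like this;
    # the real mixin is not part of this diff).
    from diffusers import FlowMatchEulerDiscreteScheduler


    class SingleSchedulerMixinSketch:
        scheduler_cls = None      # subclasses set the scheduler *class*, not an instance
        scheduler_kwargs = {}     # constructor kwargs for that class

        def get_dummy_components(self):
            # Previously this accepted a scheduler class and tests looped over a
            # `scheduler_classes` list; now it builds the one configured scheduler.
            scheduler = self.scheduler_cls(**self.scheduler_kwargs)
            return {"scheduler": scheduler}


    class ExampleLoRATestsSketch(SingleSchedulerMixinSketch):
        scheduler_cls = FlowMatchEulerDiscreteScheduler
        scheduler_kwargs = {}


    if __name__ == "__main__":
        components = ExampleLoRATestsSketch().get_dummy_components()
        print(type(components["scheduler"]).__name__)  # FlowMatchEulerDiscreteScheduler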
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = AuraFlowPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import (
     AutoencoderKLCogVideoX,
-    CogVideoXDDIMScheduler,
     CogVideoXDPMScheduler,
     CogVideoXPipeline,
     CogVideoXTransformer3DModel,
@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = CogVideoXPipeline
     scheduler_cls = CogVideoXDPMScheduler
     scheduler_kwargs = {"timestep_spacing": "trailing"}
-    scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
     transformer_kwargs = {
         "num_attention_heads": 4,
...
@@ -50,7 +50,6 @@ class TokenizerWrapper:
 class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = CogView4Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -124,30 +123,29 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(output_no_lora.shape == self.output_shape)
-
-            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                pipe.save_pretrained(tmpdirname)
-
-                pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
-                pipe_from_pretrained.to(torch_device)
-
-                images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
-
-            self.assertTrue(
-                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
-                "Loading from saved checkpoints should give same results.",
-            )
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
+
+            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+            pipe_from_pretrained.to(torch_device)
+
+            images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )

     @parameterized.expand([("block_level", True), ("leaf_level", False)])
     @require_torch_accelerator
...
@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa
 @require_peft_backend
 class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 4,
@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxControlPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 8,
...
@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = HunyuanVideoPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = LTXPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa
 class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = Lumina2Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -141,33 +140,30 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         strict=False,
     )
     def test_lora_fuse_nan(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
-
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-            # corrupt one LoRA weight with `inf` values
-            with torch.no_grad():
-                pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
-
-            # with `safe_fusing=True` we should see an Error
-            with self.assertRaises(ValueError):
-                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
-
-            # without we should not see an error, but every image will be black
-            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
-            out = pipe(**inputs)[0]
-
-            self.assertTrue(np.isnan(out).all())
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
+            pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+
+        # with `safe_fusing=True` we should see an Error
+        with self.assertRaises(ValueError):
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+        # without we should not see an error, but every image will be black
+        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+        out = pipe(**inputs)[0]
+
+        self.assertTrue(np.isnan(out).all())
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = MochiPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = QwenImagePipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
 class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = SanaPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
-    scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
+    scheduler_kwargs = {"shift": 7.0}
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 4,
...
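Note on the Sana change above: the scheduler's `shift=7.0` argument moves from a pre-built scheduler instance into `scheduler_kwargs`. Assuming the test mixin instantiates `scheduler_cls(**scheduler_kwargs)` (that call is not visible in this diff), the two spellings configure the same scheduler; a quick illustrative check:

    # Hypothetical equivalence check for the Sana change (the actual instantiation happens
    # inside the test mixin, which this diff does not show).
    from diffusers import FlowMatchEulerDiscreteScheduler

    old_style = FlowMatchEulerDiscreteScheduler(shift=7.0)           # instance stored on the class before
    new_style = FlowMatchEulerDiscreteScheduler(**{"shift": 7.0})    # built from scheduler_kwargs now

    assert old_style.config.shift == new_style.config.shift == 7.0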
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "sample_size": 32,
         "patch_size": 1,
...
@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = WanPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = WanVACEPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     @require_peft_version_greater("0.13.2")
     def test_lora_exclude_modules_wanvace(self):
-        scheduler_cls = self.scheduler_classes[0]
         exclude_module_name = "vace_blocks.0.proj_out"
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
...
This diff is collapsed.