Unverified Commit e5d0a328 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[refactor] LoRA tests (#9481)

* refactor scheduler class usage

* reorder to make tests more readable

* remove pipeline specific checks and skip tests directly

* rewrite denoiser conditions cleaner

* bump tolerance for cog test
parent 14a1b86f
...@@ -48,6 +48,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -48,6 +48,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"} scheduler_kwargs = {"timestep_spacing": "trailing"}
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
transformer_kwargs = { transformer_kwargs = {
"num_attention_heads": 4, "num_attention_heads": 4,
...@@ -126,8 +127,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -126,8 +127,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@skip_mps @skip_mps
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self):
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler] for scheduler_cls in self.scheduler_classes:
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -156,10 +156,22 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -156,10 +156,22 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(np.isnan(out).all()) self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3) super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=5e-3) super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in CogVideoX.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.") @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
......
...@@ -47,7 +47,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -47,7 +47,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {} scheduler_kwargs = {}
uses_flow_matching = True scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"patch_size": 1, "patch_size": 1,
"in_channels": 4, "in_channels": 4,
...@@ -154,6 +154,14 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -154,6 +154,14 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
) )
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Flux.")
def test_modify_padding_mode(self):
pass
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -34,7 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -34,7 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
uses_flow_matching = True scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = { transformer_kwargs = {
"sample_size": 32, "sample_size": 32,
"patch_size": 1, "patch_size": 1,
...@@ -92,3 +92,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -92,3 +92,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
lora_filename = "lora_peft_format.safetensors" lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
@unittest.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in SD3.")
def test_modify_padding_mode(self):
pass
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment