"docs/vscode:/vscode.git/clone" did not exist on "dba5cf31d7b2d851945f3fbf7425d5850a491dfd"
Unverified Commit b91e8c0d authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[lora]: Fix Flux2 LoRA NaN test (#12714)



* up

* Update tests/lora/test_lora_layers_flux2.py
Co-authored-by: default avatardg845 <58458699+dg845@users.noreply.github.com>

---------
Co-authored-by: default avatardg845 <58458699+dg845@users.noreply.github.com>
parent ac786462
...@@ -15,17 +15,18 @@ ...@@ -15,17 +15,18 @@
import sys import sys
import unittest import unittest
import numpy as np
import torch import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
from ..testing_utils import floats_tensor, require_peft_backend from ..testing_utils import floats_tensor, require_peft_backend, torch_device
sys.path.append(".") sys.path.append(".")
from .utils import PeftLoraLoaderMixinTests # noqa: E402 from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
@require_peft_backend @require_peft_backend
...@@ -94,6 +95,46 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -94,6 +95,46 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
# Overriding because (1) text encoder LoRAs are not supported in Flux 2 and (2) because the Flux 2 single block
# QKV projections are always fused, it has no `to_q` param as expected by the original test.
def test_lora_fuse_nan(self):
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
possible_tower_names = ["transformer_blocks", "single_transformer_blocks"]
filtered_tower_names = [
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
]
if len(filtered_tower_names) == 0:
reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
raise ValueError(reason)
for tower_name in filtered_tower_names:
transformer_tower = getattr(pipe.transformer, tower_name)
is_single = "single" in tower_name
if is_single:
transformer_tower[0].attn.to_qkv_mlp_proj.lora_A["adapter-1"].weight += float("inf")
else:
transformer_tower[0].attn.to_k.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
@unittest.skip("Not supported in Flux2.") @unittest.skip("Not supported in Flux2.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment