Unverified commit 67cf0445 authored by Takuma Mori, committed by GitHub

Fix to apply LoRAXFormersAttnProcessor instead of LoRAAttnProcessor when xFormers is enabled (#3556)

* fix to use LoRAXFormersAttnProcessor

* add test

* using new LoraLoaderMixin.save_lora_weights

* add test_lora_save_load_with_xformers
parent 352ca319
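
For context, here is a minimal usage sketch of the scenario this change targets. It is not part of the commit: the checkpoint id and "my_lora_dir" are placeholders, and it assumes a CUDA machine with xformers installed.

# Hypothetical usage sketch, assuming a CUDA machine with xformers installed;
# the checkpoint id and "my_lora_dir" are placeholders, not part of this commit.
import torch

from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Switch the UNet to memory-efficient attention first ...
pipe.enable_xformers_memory_efficient_attention()

# ... then load LoRA weights; with this fix the LoRA processors keep using xFormers.
pipe.load_lora_weights("my_lora_dir")

# Every attention block should now use the xFormers-backed LoRA processor.
assert all(
    isinstance(proc, LoRAXFormersAttnProcessor) for proc in pipe.unet.attn_processors.values()
)

Before this fix, the load_lora_weights call in a setup like this replaced the xFormers processors with plain LoRAAttnProcessor, silently dropping memory-efficient attention.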
@@ -27,7 +27,9 @@ from .models.attention_processor import (
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
+    XFormersAttnProcessor,
 )
 from .utils import (
     DIFFUSERS_CACHE,
@@ -279,7 +281,10 @@ class UNet2DConditionLoadersMixin:
                     attn_processor_class = LoRAAttnAddedKVProcessor
                 else:
                     cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-                    attn_processor_class = LoRAAttnProcessor
+                    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
+                        attn_processor_class = LoRAXFormersAttnProcessor
+                    else:
+                        attn_processor_class = LoRAAttnProcessor
                 attn_processors[key] = attn_processor_class(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
...
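
The heart of the patch is the processor-class dispatch above. As a standalone restatement of the same rule (the function name and the current_processor argument are illustrative only, not code from the repository):

# Illustrative restatement of the dispatch rule added above; the function name and
# the `current_processor` argument are this note's own, not code from the patch.
from diffusers.models.attention_processor import (
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)


def pick_lora_processor_class(current_processor):
    # Blocks already running on xFormers (with or without LoRA) get the
    # xFormers-backed LoRA processor; everything else keeps the vanilla one.
    if isinstance(current_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
        return LoRAXFormersAttnProcessor
    return LoRAAttnProcessor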
@@ -22,7 +22,14 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
+    AttnProcessor2_0,
+    LoRAAttnProcessor,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device
@@ -212,3 +219,90 @@ class LoraLoaderMixinTests(unittest.TestCase):
         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+    def create_lora_weight_file(self, tmpdirname):
+        _, lora_components = self.get_dummy_components()
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=tmpdirname,
+            unet_lora_layers=lora_components["unet_lora_layers"],
+            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+        )
+        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+
+    def test_lora_unet_attn_processors(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.create_lora_weight_file(tmpdirname)
+
+            pipeline_components, _ = self.get_dummy_components()
+            sd_pipe = StableDiffusionPipeline(**pipeline_components)
+            sd_pipe = sd_pipe.to(torch_device)
+            sd_pipe.set_progress_bar_config(disable=None)
+
+            # check if vanilla attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, (AttnProcessor, AttnProcessor2_0))
+
+            # load LoRA weight file
+            sd_pipe.load_lora_weights(tmpdirname)
+
+            # check if lora attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, LoRAAttnProcessor)
+
+    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    def test_lora_unet_attn_processors_with_xformers(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.create_lora_weight_file(tmpdirname)
+
+            pipeline_components, _ = self.get_dummy_components()
+            sd_pipe = StableDiffusionPipeline(**pipeline_components)
+            sd_pipe = sd_pipe.to(torch_device)
+            sd_pipe.set_progress_bar_config(disable=None)
+
+            # enable XFormers
+            sd_pipe.enable_xformers_memory_efficient_attention()
+
+            # check if xFormers attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
+            # load LoRA weight file
+            sd_pipe.load_lora_weights(tmpdirname)
+
+            # check if lora attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor)
+
+    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    def test_lora_save_load_with_xformers(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()
+
+        # enable XFormers
+        sd_pipe.enable_xformers_memory_efficient_attention()
+
+        original_images = sd_pipe(**pipeline_inputs).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            LoraLoaderMixin.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+        lora_images = sd_pipe(**pipeline_inputs).images
+        lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
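
For manual verification outside the test suite, something like the following helper can summarize which processor classes a pipeline's UNet ends up with. The helper name and return format are this note's own, not part of the commit; it mirrors the per-module inspection done in the tests above.

# Hypothetical helper for manual checks; not part of the commit.
from collections import Counter

from diffusers.models.attention_processor import Attention


def summarize_unet_attention_processors(pipe):
    # Map processor class names to how many UNet attention blocks use them.
    counts = Counter(
        type(module.processor).__name__
        for _, module in pipe.unet.named_modules()
        if isinstance(module, Attention)
    )
    return dict(counts)

After enable_xformers_memory_efficient_attention() followed by load_lora_weights(), the summary should report LoRAXFormersAttnProcessor rather than LoRAAttnProcessor.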