"packaging/vscode:/vscode.git/clone" did not exist on "2883a07bfe854849d409c1638b17841e4875d11c"
Unverified commit ad787082, authored by Batuhan Taskaya and committed by GitHub

Fix unloading of LoRAs when xformers attention procs are in use (#4179)

parent 7a47df22
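
For context, a minimal sketch of the scenario this change addresses (the model and LoRA repo ids below are placeholders, not taken from this PR):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Switch the UNet to xFormers attention processors.
pipe.enable_xformers_memory_efficient_attention()

# Loading LoRA weights installs LoRAXFormersAttnProcessor on the UNet.
pipe.load_lora_weights("some-user/some-lora")  # placeholder repo id

# Before this fix, LoRAXFormersAttnProcessor was not recognized as a LoRA
# processor here, so the LoRA layers were never detached. With the fix, the
# UNet is mapped back to plain XFormersAttnProcessor.
pipe.unload_lora_weights()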
@@ -26,6 +26,7 @@ from huggingface_hub import hf_hub_download
 from torch import nn
 
 from .models.attention_processor import (
+    LORA_ATTENTION_PROCESSORS,
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
     AttnProcessor,
@@ -1293,22 +1294,21 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
-        is_unet_lora = all(
-            isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
-            for _, processor in self.unet.attn_processors.items()
-        )
-        # Handle attention processors that are a mix of regular attention and AddedKV
-        # attention.
-        if is_unet_lora:
-            is_attn_procs_mixed = all(
-                isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
-                for _, processor in self.unet.attn_processors.items()
-            )
-            if not is_attn_procs_mixed:
-                unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
-                self.unet.set_attn_processor(unet_attn_proc_cls())
-            else:
-                self.unet.set_default_attn_processor()
+        unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
+
+        if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
+            # Handle attention processors that are a mix of regular attention and AddedKV
+            # attention.
+            if len(unet_attention_classes) > 1 or LoRAAttnAddedKVProcessor in unet_attention_classes:
+                self.unet.set_default_attn_processor()
+            else:
+                regular_attention_classes = {
+                    LoRAAttnProcessor: AttnProcessor,
+                    LoRAAttnProcessor2_0: AttnProcessor2_0,
+                    LoRAXFormersAttnProcessor: XFormersAttnProcessor,
+                }
+                [attention_proc_class] = unet_attention_classes
+                self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
@@ -167,7 +167,7 @@ class Attention(nn.Module):
     ):
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor,
-            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
+            LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -1623,6 +1623,13 @@ AttentionProcessor = Union[
     CustomDiffusionXFormersAttnProcessor,
 ]
 
+LORA_ATTENTION_PROCESSORS = (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
+)
+
 
 class SpatialNorm(nn.Module):
     """
@@ -464,6 +464,14 @@ class LoraLoaderMixinTests(unittest.TestCase):
             if isinstance(module, Attention):
                 self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor)
 
+        # unload lora weights
+        sd_pipe.unload_lora_weights()
+
+        # check if attention processors are reverted back to xFormers
+        for _, module in sd_pipe.unet.named_modules():
+            if isinstance(module, Attention):
+                self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_save_load_with_xformers(self):
         pipeline_components, lora_components = self.get_dummy_components()