Unverified commit a584d42c, authored by Patrick von Platen, committed by GitHub

[LoRA, Xformers] Fix xformers lora (#5201)

* fix xformers lora

* improve

* fix
parent cdcc01be
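
In practical terms (a minimal sketch of the intended behavior, not part of the commit; the model and LoRA repo ids below are placeholders): enabling xformers attention used to route through `set_attn_processor`, which stripped any loaded LoRA layers. After this change only `set_default_attn_processor()` removes them, and `pipe.unload_lora_weights()` remains the supported way to drop LoRA weights.

# Illustrative sketch only; model and LoRA repo ids are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA checkpoint

# Previously this call went through `set_attn_processor` and silently removed
# the LoRA layers that were just loaded; with this fix they are kept.
pipe.enable_xformers_memory_efficient_attention()

# Removing LoRA weights is done explicitly:
pipe.unload_lora_weights()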
@@ -310,19 +310,16 @@ class Attention(nn.Module):
         self.set_processor(processor)

-    def set_processor(self, processor: "AttnProcessor"):
-        if (
-            hasattr(self, "processor")
-            and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
-            and self.to_q.lora_layer is not None
-        ):
+    def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
+        if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
             deprecate(
                 "set_processor to offload LoRA",
                 "0.26.0",
-                "In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
+                "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
             )
             # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
             # We need to remove all LoRA layers
+            # Don't forget to remove ALL `_remove_lora` from the codebase
             for module in self.modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
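
A self-contained sketch of the new gating logic (toy class, not the diffusers code): LoRA layers are only removed when the caller passes `_remove_lora=True`, which after this PR only `set_default_attn_processor` does via `set_attn_processor(processor, _remove_lora=True)`.

# Toy reproduction of the control flow; the class and attribute names are illustrative.
class ToyAttention:
    def __init__(self):
        self.has_lora = True        # stands in for `self.to_q.lora_layer is not None`
        self.processor = "default"

    def set_processor(self, processor, _remove_lora=False):
        # LoRA is only stripped when the caller opts in explicitly.
        if self.has_lora and _remove_lora:
            print("deprecated: call `pipe.unload_lora_weights()` instead")
            self.has_lora = False
        self.processor = processor


attn = ToyAttention()
attn.set_processor("xformers")                    # e.g. enabling xformers: LoRA is kept
assert attn.has_lora
attn.set_processor("default", _remove_lora=True)  # e.g. set_default_attn_processor()
assert not attn.has_lora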
@@ -196,7 +196,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.

@@ -220,9 +222,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -244,7 +246,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
@@ -517,7 +517,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.

@@ -541,9 +543,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -565,7 +567,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
@@ -192,7 +192,9 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.

@@ -216,9 +218,9 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -240,7 +242,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

     def forward(
         self,
@@ -613,7 +613,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         return processors

-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.

@@ -637,9 +639,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -660,7 +662,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

     def set_attention_slice(self, slice_size):
         r"""
@@ -366,7 +366,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             fn_recursive_set_attention_slice(module, reversed_slice_size)

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.

@@ -390,9 +392,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -454,7 +456,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
@@ -538,7 +538,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.

@@ -562,9 +564,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -586,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
@@ -820,7 +820,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         return processors

-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.

@@ -844,9 +846,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -868,7 +870,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 f" {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

     def set_attention_slice(self, slice_size):
         r"""
@@ -1786,7 +1786,7 @@ class UNet3DConditionModelTests(unittest.TestCase):
         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample

-        model.set_attn_processor(AttnProcessor())
+        model.set_default_attn_processor()

         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
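
The test adjustment follows from the same change (my reading, not stated in the commit): a bare `set_attn_processor(AttnProcessor())` no longer strips LoRA layers, so the test now resets through `set_default_attn_processor()`, which passes `_remove_lora=True` internally before `new_sample` is recomputed. Roughly, the comparison the test performs looks like the following (simplified outline, not the verbatim test; the tolerance is illustrative):

# Simplified outline of the comparison in the test (names taken from the hunk above).
with torch.no_grad():
    sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample

model.set_default_attn_processor()  # actually removes LoRA layers (_remove_lora=True)

with torch.no_grad():
    new_sample = model(**inputs_dict).sample

# LoRA at scale 0.0 should match the LoRA-free model.
assert torch.allclose(sample, new_sample, atol=1e-4)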