Unverified Commit d7001400 authored by Sayak Paul, committed by GitHub

[LoRA deprecation] handle rest of the stuff related to deprecated lora stuff. (#6426)

* handle rest of the stuff related to deprecated lora stuff.

* fix: copies

* don't modify the uNet in-place.

* fix: temporal autoencoder.

* manually remove lora layers.

* don't copy unet.

* alright

* remove lora attn processors from unet3d

* fix: unet3d.

* styl

* Empty-Commit
parent 2e4dc3e2
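The diff below removes the private `_remove_lora` escape hatch everywhere, so stripping LoRA layers is no longer piggybacked on `set_attn_processor`/`set_default_attn_processor`. As a rough sketch of the two remaining ways to drop LoRA with the old (non-PEFT) backend, mirroring the pattern the updated test helpers use (the checkpoint and adapter ids are placeholders, not taken from this PR):

import torch
from diffusers import DiffusionPipeline

# Pipeline level: load, use, then fully unload the LoRA weights.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.load_lora_weights("some-user/some-lora")  # placeholder adapter id
pipe.unload_lora_weights()                     # replaces the old `_remove_lora=True` path

# Model level: with the old backend, LoRA lives in `lora_layer` attributes on the
# attention linears, so it can also be dropped manually, exactly as the test helpers do.
for module in pipe.unet.modules():
    if hasattr(module, "set_lora_layer"):
        module.set_lora_layer(None)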
@@ -494,9 +494,7 @@ class ControlNetXSModel(ModelMixin, ConfigMixin):
         """
         return self.control_model.attn_processors

-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -509,7 +507,7 @@ class ControlNetXSModel(ModelMixin, ConfigMixin):
                 processor. This is strongly recommended when setting trainable attention processors.
         """
-        self.control_model.set_attn_processor(processor, _remove_lora)
+        self.control_model.set_attn_processor(processor)

     def set_default_attn_processor(self):
         """

@@ -980,7 +980,7 @@ class LoraLoaderMixin:
         if not USE_PEFT_BACKEND:
             if version.parse(__version__) > version.parse("0.23"):
-                logger.warn(
+                logger.warning(
                     "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
                     "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                 )

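The warning above points to the PEFT-backed toggle API for the iterative case. A hedged sketch of that flow, assuming a recent `peft` is installed so diffusers uses the PEFT backend (the prompt and adapter id are illustrative):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/some-lora")  # placeholder adapter id

pipe.disable_lora()                            # keep the adapter loaded, but bypass it
baseline = pipe("an astronaut riding a horse").images[0]

pipe.enable_lora()                             # re-activate it without reloading
styled = pipe("an astronaut riding a horse").images[0]

pipe.unload_lora_weights()                     # only when the adapter should go away entirely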
@@ -373,29 +373,14 @@ class Attention(nn.Module):
         self.set_processor(processor)

-    def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
+    def set_processor(self, processor: "AttnProcessor") -> None:
         r"""
         Set the attention processor to use.

         Args:
             processor (`AttnProcessor`):
                 The attention processor to use.
-            _remove_lora (`bool`, *optional*, defaults to `False`):
-                Set to `True` to remove LoRA layers from the model.
         """
-        if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
-            deprecate(
-                "set_processor to offload LoRA",
-                "0.26.0",
-                "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
-            )
-            # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
-            # We need to remove all LoRA layers
-            # Don't forget to remove ALL `_remove_lora` from the codebase
-            for module in self.modules():
-                if hasattr(module, "set_lora_layer"):
-                    module.set_lora_layer(None)
         # if current processor is in `self._modules` and if passed `processor` is not, we need to
         # pop `processor` from `self._modules`
         if (

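With the branch above gone, `set_processor` (and therefore `set_default_attn_processor`) only swaps processor objects and never touches LoRA. A minimal sketch of the narrowed contract on a standalone `Attention` module (the constructor arguments are illustrative, not from this PR):

from diffusers.models.attention_processor import Attention, AttnProcessor

attn = Attention(query_dim=32, heads=4, dim_head=8)

# Only the processor object changes here; any `lora_layer` set on the
# attention linears (old backend) is left untouched.
attn.set_processor(AttnProcessor())

# Dropping LoRA is now an explicit, separate step:
for module in attn.modules():
    if hasattr(module, "set_lora_layer"):
        module.set_lora_layer(None)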
@@ -182,9 +182,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -208,9 +206,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -232,7 +230,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     @apply_forward_hook
     def encode(

@@ -267,9 +267,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -293,9 +291,9 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -314,7 +312,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     @apply_forward_hook
     def encode(

@@ -212,9 +212,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -238,9 +236,9 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -262,7 +260,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     @apply_forward_hook
     def encode(

@@ -534,9 +534,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -560,9 +558,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -584,7 +582,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:

@@ -192,9 +192,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -218,9 +216,9 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -242,7 +240,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     def forward(
         self,

@@ -643,9 +643,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         return processors

-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -669,9 +667,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -692,7 +690,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     def set_attention_slice(self, slice_size):
         r"""

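`set_attn_processor` keeps accepting either a single processor or a per-layer dict; only the `_remove_lora` flag is gone. A hedged usage sketch against the new signature (the model id and subfolder are placeholders):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# One processor applied to every attention layer ...
unet.set_attn_processor(AttnProcessor2_0())

# ... or a dict keyed exactly like `unet.attn_processors`, one entry per layer.
unet.set_attn_processor({name: AttnProcessor2_0() for name in unet.attn_processors})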
@@ -375,9 +375,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             fn_recursive_set_attention_slice(module, reversed_slice_size)

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -401,9 +399,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -465,7 +463,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):

@@ -549,9 +549,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -575,9 +573,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -641,7 +639,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):

@@ -237,9 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -263,9 +261,9 @@ class UVit2DModel(ModelMixin, ConfigMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -287,7 +285,7 @@ class UVit2DModel(ModelMixin, ConfigMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)


 class UVit2DConvEmbed(nn.Module):

@@ -538,9 +538,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -564,9 +562,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -588,7 +586,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):

@@ -848,9 +848,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         return processors

-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -874,9 +872,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -897,7 +895,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     def set_attention_slice(self, slice_size):
         r"""

@@ -91,9 +91,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         return processors

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(
-        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
-    ):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.

@@ -117,9 +115,9 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor, _remove_lora=_remove_lora)
+                    module.set_processor(processor)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+                    module.set_processor(processor.pop(f"{name}.processor"))

             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -141,7 +139,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )

-        self.set_attn_processor(processor, _remove_lora=True)
+        self.set_attn_processor(processor)

     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value

@@ -61,7 +61,8 @@ from diffusers.utils.testing_utils import (
 )


-def text_encoder_attn_modules(text_encoder):
+def text_encoder_attn_modules(text_encoder: nn.Module):
+    """Fetches the attention modules from `text_encoder`."""
     attn_modules = []

     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):

@@ -75,7 +76,8 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules


-def text_encoder_lora_state_dict(text_encoder):
+def text_encoder_lora_state_dict(text_encoder: nn.Module):
+    """Returns the LoRA state dict of the `text_encoder`. Assumes that `_modify_text_encoder()` was already called on it."""
     state_dict = {}

     for name, module in text_encoder_attn_modules(text_encoder):

@@ -95,6 +97,8 @@ def text_encoder_lora_state_dict(text_encoder):


 def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
+    """Creates and returns the LoRA state dict for the UNet."""
+    # So that we accidentally don't end up using the in-place modified UNet.
     unet_lora_parameters = []
     for attn_processor_name, attn_processor in unet.attn_processors.items():

@@ -145,10 +149,17 @@ def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
         unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
         unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

-    return unet_lora_parameters, unet_lora_state_dict(unet)
+    unet_lora_sd = unet_lora_state_dict(unet)
+
+    # Unload LoRA.
+    for module in unet.modules():
+        if hasattr(module, "set_lora_layer"):
+            module.set_lora_layer(None)
+
+    return unet_lora_parameters, unet_lora_sd


 def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
+    """Creates and returns the LoRA state dict for the 3D UNet."""
     for attn_processor_name in unet.attn_processors.keys():
         has_cross_attention = attn_processor_name.endswith("attn2.processor") and not (
             attn_processor_name.startswith("transformer_in") or "temp_attentions" in attn_processor_name.split(".")

@@ -216,10 +227,18 @@ def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
         attn_module.to_v.lora_layer.up.weight += 1
         attn_module.to_out[0].lora_layer.up.weight += 1

-    return unet_lora_state_dict(unet)
+    unet_lora_sd = unet_lora_state_dict(unet)
+
+    # Unload LoRA.
+    for module in unet.modules():
+        if hasattr(module, "set_lora_layer"):
+            module.set_lora_layer(None)
+
+    return unet_lora_sd


 def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0):
+    """Randomizes the LoRA params if specified."""
     if not isinstance(lora_attn_parameters, dict):
         with torch.no_grad():
             for parameter in lora_attn_parameters:

@@ -1441,6 +1460,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
 class UNet2DConditionLoRAModelTests(unittest.TestCase):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
+    lora_rank = 4

     @property
     def dummy_input(self):

@@ -1489,7 +1509,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
         with torch.no_grad():
             sample1 = model(**inputs_dict).sample

-        _, lora_params = create_unet_lora_layers(model)
+        _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank)

         # make sure we can set a list of attention processors
         model.load_attn_procs(lora_params)

@@ -1522,13 +1542,16 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample

-        _, lora_params = create_unet_lora_layers(model)
+        _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank)
         model.load_attn_procs(lora_params)

         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample

-        model.set_default_attn_processor()
+        # Unload LoRA.
+        for module in model.modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)

         with torch.no_grad():
             new_sample = model(**inputs_dict).sample

@@ -1552,7 +1575,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.to(torch_device)
-        _, lora_params = create_unet_lora_layers(model)
+        _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank)
         model.load_attn_procs(lora_params)

         # default