Unverified Commit 759ea587 authored by Sayak Paul, committed by GitHub

[core] reuse `AttentionMixin` for compatible classes (#12463)

* remove attn_processors property

* more

* up

* up more.

* up

* add AttentionMixin to AuraFlow.

* up

* up

* up

* up
parent f48f9c25
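Context for the diff below: `AttentionMixin` (in `diffusers/models/attention.py`) already implements the recursive `attn_processors` getter and `set_attn_processor` setter that each class touched here used to carry as a verbatim copy, so the PR deletes the copies and adds the mixin to the class bases. A minimal sketch of that pattern, mirroring the blocks deleted below (simplified and illustrative; the real mixin may carry additional helpers such as `fuse_qkv_projections`):

```python
from typing import Dict, Union

import torch


class AttentionMixinSketch:
    """Illustrative stand-in for diffusers' AttentionMixin; meant to be mixed
    into an nn.Module subclass so `self.named_children()` is available."""

    @property
    def attn_processors(self) -> Dict[str, object]:
        # Walk every child module and collect anything exposing `get_processor`,
        # keyed by its weight path ("<path>.processor").
        processors: Dict[str, object] = {}

        def collect(name: str, module: torch.nn.Module):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                collect(f"{name}.{sub_name}", child)

        for name, module in self.named_children():
            collect(name, module)
        return processors

    def set_attn_processor(self, processor: Union[object, Dict[str, object]]):
        # Accept either a single processor (applied everywhere) or a dict keyed
        # by the same "<path>.processor" names that `attn_processors` returns.
        count = len(self.attn_processors)
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(f"Expected {count} processors, got {len(processor)}.")

        def apply(name: str, module: torch.nn.Module):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                apply(f"{name}.{sub_name}", child)

        for name, module in self.named_children():
            apply(name, module)
```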
@@ -13,15 +13,15 @@
# limitations under the License.
-from typing import Dict, Tuple, Union
+from typing import Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
-from ..attention import FeedForward
+from ..attention import AttentionMixin, FeedForward
-from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
+from ..attention_processor import Attention, CogVideoXAttnProcessor2_0
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -125,7 +125,7 @@ class CogView3PlusTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
+class CogView3PlusTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin):
r"""
The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
Diffusion](https://huggingface.co/papers/2403.05121).
@@ -224,66 +224,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
...
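A quick way to confirm the consolidation for the first class touched above (assuming this diff is applied to the installed `diffusers` checkout; the same check applies to the other classes in the PR):

```python
from diffusers import CogView3PlusTransformer2DModel
from diffusers.models.attention import AttentionMixin

# The class no longer defines its own copies; both helpers now resolve
# through AttentionMixin in the MRO.
assert issubclass(CogView3PlusTransformer2DModel, AttentionMixin)
assert "attn_processors" not in vars(CogView3PlusTransformer2DModel)
assert "set_attn_processor" not in vars(CogView3PlusTransformer2DModel)
assert hasattr(CogView3PlusTransformer2DModel, "set_attn_processor")
```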
@@ -23,9 +23,9 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
+from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
@@ -838,7 +838,9 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class HunyuanVideoTransformer3DModel(
+    ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -987,66 +989,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
...
@@ -23,9 +23,9 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
+from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
@@ -494,7 +494,9 @@ class HunyuanVideo15TransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class HunyuanVideo15Transformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
r"""
A Transformer model for video-like data used in [HunyuanVideo1.5](https://huggingface.co/tencent/HunyuanVideo1.5).
@@ -604,66 +606,6 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
...
@@ -25,9 +25,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
+from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
@@ -616,7 +616,9 @@ class HunyuanImageTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
-class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class HunyuanImageTransformer2DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
r"""
The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
@@ -667,10 +669,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"HunyuanImagePatchEmbed",
"HunyuanImageTokenRefiner",
]
-_repeated_blocks = [
-"HunyuanImageTransformerBlock",
-"HunyuanImageSingleTransformerBlock",
-]
+_repeated_blocks = ["HunyuanImageTransformerBlock", "HunyuanImageSingleTransformerBlock"]
@register_to_config
def __init__(
@@ -743,66 +742,6 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
...
@@ -20,10 +20,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward, JointTransformerBlock
+from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
from ..attention_processor import (
Attention,
-AttentionProcessor,
FusedJointAttnProcessor2_0,
JointAttnProcessor2_0,
)
@@ -78,7 +77,7 @@ class SD3SingleTransformerBlock(nn.Module):
class SD3Transformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+    ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
"""
The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
@@ -214,66 +213,6 @@ class SD3Transformer2DModel(
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
def fuse_qkv_projections(self):
"""
...
@@ -22,11 +22,11 @@ from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
-AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
@@ -68,7 +68,7 @@ class UNet2DConditionOutput(BaseOutput):
class UNet2DConditionModel(
-    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
+    ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
@@ -701,64 +701,6 @@ class UNet2DConditionModel(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
...
@@ -23,11 +23,11 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, logging
from ..activations import get_activation
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
-AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
@@ -58,7 +58,7 @@ class UNet3DConditionOutput(BaseOutput):
sample: torch.Tensor
-class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class UNet3DConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.
@@ -285,31 +285,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
@@ -376,41 +351,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
...
@@ -21,11 +21,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import logging
from ..activations import get_activation
-from ..attention import Attention, FeedForward
+from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
-AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
@@ -93,7 +92,7 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
return hidden_states
-class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class I2VGenXLUNet(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and
returns a sample-shaped output.
@@ -313,66 +312,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
self.conv_act = get_activation("silu")
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
...
@@ -13,14 +13,15 @@
# limitations under the License.
from dataclasses import dataclass
-from typing import Dict, Tuple, Union
+from typing import Tuple, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
-from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
+from ..attention import AttentionMixin
+from ..attention_processor import Attention, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
@@ -45,7 +46,7 @@ class Kandinsky3EncoderProj(nn.Module):
return x
-class Kandinsky3UNet(ModelMixin, ConfigMixin):
+class Kandinsky3UNet(ModelMixin, AttentionMixin, ConfigMixin):
@register_to_config
def __init__(
self,
@@ -140,64 +141,6 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
self.conv_act_out = nn.SiLU()
self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
...
@@ -23,12 +23,11 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin, BasicTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
-AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
AttnProcessor2_0,
@@ -1195,7 +1194,7 @@ class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
pass
-class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
r"""
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
sample shaped output.
@@ -1754,66 +1753,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
**kwargs,
)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
...
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -7,7 +7,8 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, logging
-from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from ..attention import AttentionMixin
+from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
@@ -29,7 +30,7 @@ class UNetSpatioTemporalConditionOutput(BaseOutput):
sample: torch.Tensor = None
-class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class UNetSpatioTemporalConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and
returns a sample shaped output.
@@ -245,68 +246,6 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
padding=1,
)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
...
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Union
import torch
import torch.nn.functional as F
@@ -22,11 +21,10 @@ from torch.utils.checkpoint import checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
+from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
-AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -36,7 +34,7 @@ from ..normalization import GlobalResponseNorm, RMSNorm
from ..resnet import Downsample2D, Upsample2D
-class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
@@ -209,66 +207,6 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return logits
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
...
@@ -21,10 +21,10 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...models.activations import get_activation
+from ...models.attention import AttentionMixin
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
-AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -163,7 +163,7 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
)
-class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class AudioLDM2UNet2DConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional
@@ -535,66 +535,6 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
...
@@ -14,17 +14,16 @@
# limitations under the License.
import math
-from typing import Dict, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from ...models.attention import AttentionMixin
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
-AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -32,7 +31,7 @@ from ...models.modeling_utils import ModelMixin
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
-class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class WuerstchenPrior(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
unet_name = "prior"
_supports_gradient_checkpointing = True
@@ -61,66 +60,6 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
self.gradient_checkpointing = False
self.set_default_attn_processor()
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
...
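Caller-facing behavior is unchanged by the consolidation: the mixin exposes the same `attn_processors` / `set_attn_processor` surface the per-class copies did. A small usage sketch (the tiny UNet configuration below is illustrative only, chosen so the model builds quickly):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Tiny, illustrative configuration; any UNet2DConditionModel works the same way.
unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=32,
    norm_num_groups=8,
)

print(len(unet.attn_processors))             # one entry per Attention layer, keyed by weight path
unet.set_attn_processor(AttnProcessor2_0())  # same processor applied to every layer
```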