Unverified Commit a8315ce1 authored by Patrick von Platen, committed by GitHub

[UNet3DModel] Fix with attn processor (#2790)

* [UNet3DModel] Fix attn processor

* make style
parent 0d633a42
@@ -21,6 +21,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
+from .attention_processor import AttentionProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -249,6 +250,32 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
@@ -259,24 +286,24 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
         sliceable_head_dims = []
 
-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
             if hasattr(module, "set_attention_slice"):
                 sliceable_head_dims.append(module.sliceable_head_dim)
 
             for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
+                fn_recursive_retrieve_sliceable_dims(child)
 
         # retrieve number of attention layers
         for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
+            fn_recursive_retrieve_sliceable_dims(module)
 
-        num_slicable_layers = len(sliceable_head_dims)
+        num_sliceable_layers = len(sliceable_head_dims)
 
         if slice_size == "auto":
             # half the attention head size is usually a good trade-off between
@@ -284,9 +311,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
             slice_size = [dim // 2 for dim in sliceable_head_dims]
         elif slice_size == "max":
             # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
+            slice_size = num_sliceable_layers * [1]
 
-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
 
         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
@@ -314,6 +341,37 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the
+                processor of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the
+                path to the corresponding cross attention processor. This is strongly recommended when setting
+                trainable attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
             module.gradient_checkpointing = value
...
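For context, a minimal usage sketch of the two methods this commit adds (not part of the diff). The checkpoint id `damo-vilab/text-to-video-ms-1.7b` and the example key shown in the comment are assumptions for illustration; any checkpoint containing a `UNet3DConditionModel` in a `unet` subfolder would work the same way.

```python
# Sketch only: assumes diffusers >= 0.15 with UNet3DConditionModel available.
import torch

from diffusers import UNet3DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet3DConditionModel.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", subfolder="unet", torch_dtype=torch.float16
)

# `attn_processors` walks the module tree and returns one entry per attention layer,
# keyed by the module path, e.g. "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor".
processors = unet.attn_processors
print(len(processors), sorted(processors)[0])

# `set_attn_processor` accepts either a single processor applied to every layer ...
unet.set_attn_processor(AttnProcessor())

# ... or a dict with exactly one entry per key of `attn_processors`,
# which is the form you would use for per-layer (e.g. trainable) processors.
unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})
```

Passing a dict with a different number of entries than `attn_processors` raises the `ValueError` defined above, which is why the recommended pattern is to build the dict from the property's own keys.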
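The `sliceable`/`slicable` renames above only touch internal helpers of `set_attention_slice`; its public behavior is unchanged. A quick sketch of the accepted arguments, continuing the assumed `unet` from the previous example (the integer case additionally requires each sliceable head dim to be a multiple of the value, per the docstring):

```python
unet.set_attention_slice("auto")  # halve each sliceable head dim: attention runs in two steps
unet.set_attention_slice("max")   # slice size 1 for every layer: maximum memory savings
unet.set_attention_slice(2)       # same integer slice size for every sliceable layer
```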