Unverified Commit 8dba1808 authored by Vishnu V Jaddipal's avatar Vishnu V Jaddipal Committed by GitHub
Browse files

Added support to create asymmetrical U-Net structures (#5400)



* Added args, kwargs to ```U

* Add UNetMidBlock2D as a supported mid block type

* Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Updated docstring, increased check strictness

Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block``` and updated checking for nested list type ```transformer_layers_per_block```

* Add basic shape-check test for asymmetrical unets

* Update src/diffusers/models/unet_2d_blocks.py

Removed blank line
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_condition.py

Remove blank space

* Update unet_2d_condition.py

Changed docstring for `mid_block_type`

* Fixed docstring and wrong default value

* Reformat with black

* Reformat with necessary commands

* Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency

* Removed args, kwargs, use on mid-block type

* Make fix-copies

* Update src/diffusers/models/unet_2d_condition.py

Wrap into single line
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* make fix-copies

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 5366db5d
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
......@@ -634,7 +634,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -654,6 +654,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# there is always at least one resnet
resnets = [
ResnetBlock2D(
......@@ -671,14 +675,14 @@ class UNetMidBlock2DCrossAttn(nn.Module):
]
attentions = []
for _ in range(num_layers):
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......@@ -1018,7 +1022,7 @@ class CrossAttnDownBlock2D(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -1041,6 +1045,8 @@ class CrossAttnDownBlock2D(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
......@@ -1064,7 +1070,7 @@ class CrossAttnDownBlock2D(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......@@ -2167,7 +2173,7 @@ class CrossAttnUpBlock2D(nn.Module):
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -2190,6 +2196,9 @@ class CrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
......@@ -2214,7 +2223,7 @@ class CrossAttnUpBlock2D(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......
......@@ -43,6 +43,7 @@ from .embeddings import (
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
get_down_block,
......@@ -86,7 +87,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
......@@ -105,10 +106,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
......@@ -142,9 +148,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
*optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
......@@ -184,7 +190,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
......@@ -265,6 +272,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# input
conv_in_padding = (conv_in_kernel - 1) // 2
......@@ -500,6 +511,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type == "UNetMidBlock2D":
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
num_layers=0,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=False,
)
elif mid_block_type is None:
self.mid_block = None
else:
......@@ -513,7 +537,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
reversed_transformer_layers_per_block = (
list(reversed(transformer_layers_per_block))
if reverse_transformer_layers_per_block is None
else reverse_transformer_layers_per_block
)
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
......@@ -1062,14 +1090,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = self.mid_block(sample, emb)
# To support T2I-Adapter-XL
if (
is_adapter
......
......@@ -106,6 +106,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
......
......@@ -134,6 +134,7 @@ class AltDiffusionImg2ImgPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
......
......@@ -324,7 +324,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
if "num_classes" in unet_params and type(unet_params.num_classes) == int:
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes
if controlnet:
......
......@@ -281,7 +281,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or
Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or
`UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`):
The tuple of upsample blocks to use.
......@@ -300,10 +300,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
[`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
[`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
[`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
......@@ -337,9 +342,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
*optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
......@@ -384,7 +389,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
......@@ -475,6 +481,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
f" {layers_per_block}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# input
conv_in_padding = (conv_in_kernel - 1) // 2
......@@ -710,6 +720,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type == "UNetMidBlockFlat":
self.mid_block = UNetMidBlockFlat(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
num_layers=0,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=False,
)
elif mid_block_type is None:
self.mid_block = None
else:
......@@ -723,7 +746,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
reversed_transformer_layers_per_block = (
list(reversed(transformer_layers_per_block))
if reverse_transformer_layers_per_block is None
else reverse_transformer_layers_per_block
)
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
......@@ -1281,14 +1308,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = self.mid_block(sample, emb)
# To support T2I-Adapter-XL
if (
is_adapter
......@@ -1557,7 +1588,7 @@ class CrossAttnDownBlockFlat(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -1580,6 +1611,8 @@ class CrossAttnDownBlockFlat(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
......@@ -1603,7 +1636,7 @@ class CrossAttnDownBlockFlat(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......@@ -1823,7 +1856,7 @@ class CrossAttnUpBlockFlat(nn.Module):
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -1846,6 +1879,9 @@ class CrossAttnUpBlockFlat(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
......@@ -1870,7 +1906,7 @@ class CrossAttnUpBlockFlat(nn.Module):
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......@@ -1983,6 +2019,133 @@ class CrossAttnUpBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
Args:
in_channels (`int`): The number of input channels.
temb_channels (`int`): The number of temporal embedding channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
model on tasks with long-range temporal dependencies.
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
Whether to use pre-normalization for the resnet blocks.
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
attention_head_dim (`int`, *optional*, defaults to 1):
Dimension of a single attention head. The number of attention heads is determined based on this value and
the number of input channels.
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, height, width)`.
"""
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim=1,
output_scale_factor=1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
if attn_groups is None:
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
# there is always at least one resnet
resnets = [
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
logger.warn(
"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to"
f" `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=attn_groups,
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
def __init__(
......@@ -1991,7 +2154,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
......@@ -2011,6 +2174,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# there is always at least one resnet
resnets = [
ResnetBlockFlat(
......@@ -2028,14 +2195,14 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
]
attentions = []
for _ in range(num_layers):
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
......
......@@ -606,6 +606,22 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
# Add asymmetry to configs
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
output = model(**inputs_dict).sample
expected_shape = inputs_dict["sample"].shape
# Check if input and output shapes are the same
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment