Unverified commit 8dba1808 authored by Vishnu V Jaddipal, committed by GitHub

Added support to create asymmetrical U-Net structures (#5400)



* Added args, kwargs to ```U

* Add UNetMidBlock2D as a supported mid block type

* Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Updated docstring, increased check strictness

Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block```, and tightened the type check for a nested-list ```transformer_layers_per_block```

* Add basic shape-check test for asymmetrical unets

* Update src/diffusers/models/unet_2d_blocks.py

Removed blank line
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_condition.py

Remove blank space

* Update unet_2d_condition.py

Changed docstring for `mid_block_type`

* Fixed docstring and wrong default value

* Reformat with black

* Reformat with necessary commands

* Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency

* Removed args, kwargs, use on mid-block type

* Make fix-copies

* Update src/diffusers/models/unet_2d_condition.py

Wrap into single line
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make fix-copies

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 5366db5d
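For orientation before the diff: this change lets the down and up halves of the U-Net use different transformer depths per block. A minimal usage sketch follows, mirroring the shape-check test added at the end of this commit; the small config values (channel counts, `sample_size`, `cross_attention_dim`) are illustrative assumptions, not part of the commit itself.

```python
import torch
from diffusers import UNet2DConditionModel

# Asymmetric U-Net: per-layer transformer depths on the down path differ
# from those on the up path. Config values below are illustrative.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    # Nested list: the cross-attention down block gets 3 and 2 transformer
    # layers in its two sub-layers; non-attention blocks ignore their entry.
    transformer_layers_per_block=[[3, 2], 1],
    # Required whenever the spec above is nested: depths for the up path.
    reverse_transformer_layers_per_block=[[3, 4], 1],
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep, encoder_hidden_states).sample
assert out.shape == sample.shape  # the U-Net preserves input shape
```

Passing a nested list for `transformer_layers_per_block` makes the U-Net asymmetric, and the constructor then requires `reverse_transformer_layers_per_block` for the up path (see the new `ValueError` check in the diff below).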
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -634,7 +634,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -654,6 +654,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
@@ -671,14 +675,14 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         ]
         attentions = []

-        for _ in range(num_layers):
+        for i in range(num_layers):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
                         in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1018,7 +1022,7 @@ class CrossAttnDownBlock2D(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1041,6 +1045,8 @@ class CrossAttnDownBlock2D(nn.Module):
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads

+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -1064,7 +1070,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -2167,7 +2173,7 @@ class CrossAttnUpBlock2D(nn.Module):
         resolution_idx: int = None,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -2190,6 +2196,9 @@ class CrossAttnUpBlock2D(nn.Module):
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads

+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -2214,7 +2223,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
...
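The recurring pattern in the hunks above: a scalar `transformer_layers_per_block` is broadcast to one entry per resnet layer, so downstream code can index it uniformly with `[i]`. A standalone sketch of that normalization (function name is illustrative, not from the diff):

```python
from typing import List, Sequence, Union

def normalize_layer_spec(spec: Union[int, Sequence[int]], num_layers: int) -> List[int]:
    # Broadcast a scalar depth to one entry per layer; leave sequences as-is.
    if isinstance(spec, int):
        return [spec] * num_layers
    return list(spec)

assert normalize_layer_spec(1, 3) == [1, 1, 1]
assert normalize_layer_spec((3, 2), 2) == [3, 2]
```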
@@ -43,6 +43,7 @@ from .embeddings import (
 )
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
+    UNetMidBlock2D,
     UNetMidBlock2DCrossAttn,
     UNetMidBlock2DSimpleCrossAttn,
     get_down_block,
@@ -86,7 +87,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
             The tuple of downsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
-            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
             `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
@@ -105,10 +106,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
-        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
             The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
             [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
             [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
         encoder_hid_dim (`int`, *optional*, defaults to None):
             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
             dimension to `cross_attention_dim`.
@@ -142,9 +148,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
         time_cond_proj_dim (`int`, *optional*, defaults to `None`):
             The dimension of `cond_proj` layer in the timestep embedding.
-        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
-        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
-        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
+            *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
+            *optional*): The dimension of the `class_labels` input when
             `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
             embeddings with the class embeddings.
@@ -184,7 +190,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
-        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -265,6 +272,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

         # input
         conv_in_padding = (conv_in_kernel - 1) // 2
@@ -500,6 +511,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 only_cross_attention=mid_block_only_cross_attention,
                 cross_attention_norm=cross_attention_norm,
             )
+        elif mid_block_type == "UNetMidBlock2D":
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                dropout=dropout,
+                num_layers=0,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                add_attention=False,
+            )
         elif mid_block_type is None:
             self.mid_block = None
         else:
@@ -513,7 +537,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
-        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+        reversed_transformer_layers_per_block = (
+            list(reversed(transformer_layers_per_block))
+            if reverse_transformer_layers_per_block is None
+            else reverse_transformer_layers_per_block
+        )
         only_cross_attention = list(reversed(only_cross_attention))

         output_channel = reversed_block_out_channels[0]
@@ -1062,14 +1090,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         # 4. mid
         if self.mid_block is not None:
-            sample = self.mid_block(
-                sample,
-                emb,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                cross_attention_kwargs=cross_attention_kwargs,
-                encoder_attention_mask=encoder_attention_mask,
-            )
+            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+                sample = self.mid_block(
+                    sample,
+                    emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = self.mid_block(sample, emb)

             # To support T2I-Adapter-XL
             if (
                 is_adapter
...
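The forward-pass change above dispatches on whether the mid block accepts conditioning, so the plain `UNetMidBlock2D` can be dropped in without changing the call sites. A minimal sketch of the same duck-typing pattern, with toy modules standing in for the real blocks (illustrative only):

```python
import torch
import torch.nn as nn

class PlainMid(nn.Module):
    # Stand-in for UNetMidBlock2D: takes only the sample and time embedding.
    def forward(self, sample, emb):
        return sample

class CrossAttnMid(nn.Module):
    # Stand-in for UNetMidBlock2DCrossAttn: advertises cross-attention support.
    has_cross_attention = True

    def forward(self, sample, emb, encoder_hidden_states=None):
        return sample

def run_mid(mid_block, sample, emb, encoder_hidden_states=None):
    # Mirror of the dispatch in the diff: only pass conditioning kwargs to
    # blocks that advertise cross-attention support.
    if hasattr(mid_block, "has_cross_attention") and mid_block.has_cross_attention:
        return mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
    return mid_block(sample, emb)

x, emb = torch.randn(1, 4, 8, 8), torch.randn(1, 16)
run_mid(PlainMid(), x, emb)
run_mid(CrossAttnMid(), x, emb, encoder_hidden_states=torch.randn(1, 77, 32))
```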
@@ -106,6 +106,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         feature_extractor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
+
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
...
@@ -134,6 +134,7 @@ class AltDiffusionImg2ImgPipeline(
         feature_extractor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
+
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
...
@@ -324,7 +324,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
     if "disable_self_attentions" in unet_params:
         config["only_cross_attention"] = unet_params.disable_self_attentions

-    if "num_classes" in unet_params and type(unet_params.num_classes) == int:
+    if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
        config["num_class_embeds"] = unet_params.num_classes

    if controlnet:
...
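One behavioral nuance in the `type(...) == int` → `isinstance(...)` cleanup above: the two are not strictly equivalent, because `isinstance` also accepts subclasses of `int`, notably `bool`. A quick illustration:

```python
num_classes = True  # bool is a subclass of int
print(type(num_classes) == int)      # False
print(isinstance(num_classes, int))  # True
```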
@@ -281,7 +281,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
             The tuple of downsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
-            Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or
+            Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or
             `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`):
             The tuple of upsample blocks to use.
@@ -300,10 +300,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
-        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
             The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
             [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
             [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
+        reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+            [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
+            [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
         encoder_hid_dim (`int`, *optional*, defaults to None):
             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
             dimension to `cross_attention_dim`.
@@ -337,9 +342,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
         time_cond_proj_dim (`int`, *optional*, defaults to `None`):
             The dimension of `cond_proj` layer in the timestep embedding.
-        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
-        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
-        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
+            *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
+            *optional*): The dimension of the `class_labels` input when
             `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
             embeddings with the class embeddings.
@@ -384,7 +389,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
-        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -475,6 +481,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
                 f" {layers_per_block}. `down_block_types`: {down_block_types}."
             )
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

         # input
         conv_in_padding = (conv_in_kernel - 1) // 2
@@ -710,6 +720,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=mid_block_only_cross_attention,
                 cross_attention_norm=cross_attention_norm,
             )
+        elif mid_block_type == "UNetMidBlockFlat":
+            self.mid_block = UNetMidBlockFlat(
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                dropout=dropout,
+                num_layers=0,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                add_attention=False,
+            )
         elif mid_block_type is None:
             self.mid_block = None
         else:
@@ -723,7 +746,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
-        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+        reversed_transformer_layers_per_block = (
+            list(reversed(transformer_layers_per_block))
+            if reverse_transformer_layers_per_block is None
+            else reverse_transformer_layers_per_block
+        )
         only_cross_attention = list(reversed(only_cross_attention))

         output_channel = reversed_block_out_channels[0]
@@ -1281,14 +1308,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # 4. mid
         if self.mid_block is not None:
-            sample = self.mid_block(
-                sample,
-                emb,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                cross_attention_kwargs=cross_attention_kwargs,
-                encoder_attention_mask=encoder_attention_mask,
-            )
+            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+                sample = self.mid_block(
+                    sample,
+                    emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = self.mid_block(sample, emb)

             # To support T2I-Adapter-XL
             if (
                 is_adapter
@@ -1557,7 +1588,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1580,6 +1611,8 @@ class CrossAttnDownBlockFlat(nn.Module):
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads

+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -1603,7 +1636,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1823,7 +1856,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         resolution_idx: int = None,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1846,6 +1879,9 @@ class CrossAttnUpBlockFlat(nn.Module):
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads

+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1870,7 +1906,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1983,6 +2019,133 @@ class CrossAttnUpBlockFlat(nn.Module):
         return hidden_states

+# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
+class UNetMidBlockFlat(nn.Module):
+    """
+    A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
+
+    Args:
+        in_channels (`int`): The number of input channels.
+        temb_channels (`int`): The number of temporal embedding channels.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
+        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
+        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
+        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
+            The type of normalization to apply to the time embeddings. This can help to improve the performance of the
+            model on tasks with long-range temporal dependencies.
+        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
+        resnet_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use in the group normalization layers of the resnet blocks.
+        attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
+        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
+            Whether to use pre-normalization for the resnet blocks.
+        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
+        attention_head_dim (`int`, *optional*, defaults to 1):
+            Dimension of a single attention head. The number of attention heads is determined based on this value and
+            the number of input channels.
+        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
+
+    Returns:
+        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
+        in_channels, height, width)`.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",  # default, spatial
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        attn_groups: Optional[int] = None,
+        resnet_pre_norm: bool = True,
+        add_attention: bool = True,
+        attention_head_dim=1,
+        output_scale_factor=1.0,
+    ):
+        super().__init__()
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+        self.add_attention = add_attention
+
+        if attn_groups is None:
+            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlockFlat(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        if attention_head_dim is None:
+            logger.warn(
+                "It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to"
+                f" `in_channels`: {in_channels}."
+            )
+            attention_head_dim = in_channels
+
+        for _ in range(num_layers):
+            if self.add_attention:
+                attentions.append(
+                    Attention(
+                        in_channels,
+                        heads=in_channels // attention_head_dim,
+                        dim_head=attention_head_dim,
+                        rescale_output_factor=output_scale_factor,
+                        eps=resnet_eps,
+                        norm_num_groups=attn_groups,
+                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+                        residual_connection=True,
+                        bias=True,
+                        upcast_softmax=True,
+                        _from_deprecated_attn_block=True,
+                    )
+                )
+            else:
+                attentions.append(None)
+
+            resnets.append(
+                ResnetBlockFlat(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(self, hidden_states, temb=None):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            if attn is not None:
+                hidden_states = attn(hidden_states, temb=temb)
+            hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
 # Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
 class UNetMidBlockFlatCrossAttn(nn.Module):
     def __init__(
@@ -1991,7 +2154,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -2011,6 +2174,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
         # there is always at least one resnet
         resnets = [
             ResnetBlockFlat(
@@ -2028,14 +2195,14 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         ]
         attentions = []

-        for _ in range(num_layers):
+        for i in range(num_layers):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
                         in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
...
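Worth noting from the new mid-block branches above: instantiating the block with `num_layers=0` and `add_attention=False` reduces it to a single ResNet with no attention, since the constructor always builds at least one resnet and the layer loop then adds nothing. A sketch using the 2D original directly; the import path reflects the diffusers layout this commit targets and should be treated as an assumption:

```python
import torch
from diffusers.models.unet_2d_blocks import UNetMidBlock2D

# With num_layers=0 and add_attention=False the block degenerates to a single
# ResNet, matching how the "UNetMidBlock2D" mid_block_type is wired up here.
mid = UNetMidBlock2D(
    in_channels=64,
    temb_channels=128,
    num_layers=0,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    add_attention=False,
)
x = torch.randn(1, 64, 8, 8)
temb = torch.randn(1, 128)
out = mid(x, temb)
assert out.shape == x.shape
```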
@@ -606,6 +606,22 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         assert (sample - sample_copy).abs().max() < 1e-4

+    def test_asymmetrical_unet(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        # Add asymmetry to configs
+        init_dict["transformer_layers_per_block"] = [[3, 2], 1]
+        init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        output = model(**inputs_dict).sample
+        expected_shape = inputs_dict["sample"].shape
+
+        # Check if input and output shapes are the same
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):
...