Unverified commit 88d26946 authored by Patrick von Platen, committed by GitHub

Correct bad attn naming (#3797)



* relax tolerance slightly

* correct incorrect naming

* correct naming

* correct more

* Apply suggestions from code review

* Fix more

* Correct more

* correct incorrect naming

* Update src/diffusers/models/controlnet.py

* Correct flax

* Correct renaming

* Correct blocks

* Fix more

* Correct more

* make style

* make style

* make style

* make style

* make style

* Fix flax

* make style

* rename

* rename

* rename attn head dim to attention_head_dim

* correct flax

* make style

* improve

* Correct more

* make style

* fix more

* make style

* Update src/diffusers/models/controlnet_flax.py

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 0c6d1bc9
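
For context, the heart of this change is a backward-compatible fallback added to each model's `__init__`. Below is a minimal, hedged sketch of that logic; the `ToyAttnConfig` class is illustrative only and not part of diffusers, but it mirrors the fallback and broadcasting done in the diff:

```python
from typing import Optional, Tuple, Union

# Sketch (not the actual diffusers code) of the backward-compatible fallback:
# configs that only set the historically misnamed `attention_head_dim` keep
# working, while new configs can pass `num_attention_heads` explicitly.
class ToyAttnConfig:
    def __init__(
        self,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        down_block_types: Tuple[str, ...] = ("CrossAttnDownBlock2D",) * 4,
    ):
        # Fall back to `attention_head_dim` when `num_attention_heads` is unset,
        # mirroring `num_attention_heads = num_attention_heads or attention_head_dim`.
        num_attention_heads = num_attention_heads or attention_head_dim
        # Broadcast a single int to one value per down block, as the diff does.
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)
        if len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `num_attention_heads` as `down_block_types`."
            )
        self.num_attention_heads = num_attention_heads


# Old-style config: only `attention_head_dim` is set, so it is reused as the head count.
assert ToyAttnConfig(attention_head_dim=8).num_attention_heads == (8, 8, 8, 8)
# New-style config: `num_attention_heads` takes precedence.
assert ToyAttnConfig(num_attention_heads=(5, 10, 20, 20)).num_attention_heads == (5, 10, 20, 20)
```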
......@@ -112,6 +112,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
......@@ -124,6 +125,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
......@@ -135,9 +144,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
......@@ -198,6 +207,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
......@@ -221,7 +233,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -255,7 +268,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
......@@ -292,6 +305,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
......@@ -390,8 +404,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
......@@ -129,6 +129,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
The number of layers per block.
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
num_attention_heads (`int` or `Tuple[int]`, *optional*):
The number of attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
......@@ -155,6 +157,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
......@@ -182,6 +185,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = self.num_attention_heads or self.attention_head_dim
# input
self.conv_in = nn.Conv(
block_out_channels[0],
......@@ -206,9 +217,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
attention_head_dim = self.attention_head_dim
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
# down
down_blocks = []
......@@ -237,7 +247,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -285,7 +295,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
dropout=self.dropout,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)
......
......@@ -164,7 +164,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
......@@ -178,7 +178,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attn_num_head_channels=attention_head_dim,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
add_attention=add_attention,
)
......@@ -204,7 +204,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.up_blocks.append(up_block)
......
......@@ -33,7 +33,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
......@@ -46,7 +46,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
out_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
num_attention_heads: int = 1
add_downsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
......@@ -70,8 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
......@@ -172,7 +172,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
......@@ -186,7 +186,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
prev_output_channel: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
num_attention_heads: int = 1
add_upsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
......@@ -211,8 +211,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
......@@ -317,7 +317,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
......@@ -327,7 +327,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
num_attention_heads: int = 1
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
......@@ -348,8 +348,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
for _ in range(self.num_layers):
attn_block = FlaxTransformer2DModel(
in_channels=self.in_channels,
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
n_heads=self.num_attention_heads,
d_head=self.in_channels // self.num_attention_heads,
depth=1,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
......
......@@ -103,6 +103,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
......@@ -169,6 +171,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
......@@ -195,6 +198,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.sample_size = sample_size
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
......@@ -211,6 +222,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
......@@ -353,6 +369,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
......@@ -388,7 +407,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -398,6 +417,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.down_blocks.append(down_block)
......@@ -411,7 +431,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -425,7 +445,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
......@@ -442,7 +462,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
......@@ -474,7 +494,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -483,6 +503,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
......@@ -575,8 +596,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
......@@ -81,6 +81,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The number of layers per block.
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
num_attention_heads (`int` or `Tuple[int]`, *optional*):
The number of attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
......@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
......@@ -131,6 +134,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = self.num_attention_heads or self.attention_head_dim
# input
self.conv_in = nn.Conv(
block_out_channels[0],
......@@ -150,9 +161,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
attention_head_dim = self.attention_head_dim
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
# down
down_blocks = []
......@@ -168,7 +178,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -192,7 +202,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
......@@ -201,7 +211,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_num_attention_heads = list(reversed(num_attention_heads))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(self.up_block_types):
......@@ -217,7 +227,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
add_upsample=not is_final_block,
dropout=self.dropout,
use_linear_projection=self.use_linear_projection,
......
......@@ -29,7 +29,7 @@ def get_down_block(
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
......@@ -66,7 +66,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -86,7 +86,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
......@@ -122,7 +122,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -144,7 +144,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
......@@ -154,7 +154,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
......@@ -185,8 +185,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
in_channels // attn_num_head_channels,
attn_num_head_channels,
in_channels // num_attention_heads,
num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -197,8 +197,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
)
temp_attentions.append(
TransformerTemporalModel(
in_channels // attn_num_head_channels,
attn_num_head_channels,
in_channels // num_attention_heads,
num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -273,7 +273,7 @@ class CrossAttnDownBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
......@@ -290,7 +290,7 @@ class CrossAttnDownBlock3D(nn.Module):
temp_convs = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
......@@ -317,8 +317,8 @@ class CrossAttnDownBlock3D(nn.Module):
)
attentions.append(
Transformer2DModel(
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // num_attention_heads,
num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -330,8 +330,8 @@ class CrossAttnDownBlock3D(nn.Module):
)
temp_attentions.append(
TransformerTemporalModel(
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // num_attention_heads,
num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -486,7 +486,7 @@ class CrossAttnUpBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
......@@ -502,7 +502,7 @@ class CrossAttnUpBlock3D(nn.Module):
temp_attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
......@@ -531,8 +531,8 @@ class CrossAttnUpBlock3D(nn.Module):
)
attentions.append(
Transformer2DModel(
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // num_attention_heads,
num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -544,8 +544,8 @@ class CrossAttnUpBlock3D(nn.Module):
)
temp_attentions.append(
TransformerTemporalModel(
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // num_attention_heads,
num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......
......@@ -79,6 +79,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*): The number of attention heads.
"""
_supports_gradient_checkpointing = False
......@@ -105,11 +106,20 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps: float = 1e-5,
cross_attention_dim: int = 1024,
attention_head_dim: Union[int, Tuple[int]] = 64,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
):
super().__init__()
self.sample_size = sample_size
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
......@@ -121,9 +131,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
......@@ -156,8 +166,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
......@@ -177,7 +187,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
dual_cross_attention=False,
)
......@@ -191,7 +201,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
)
......@@ -201,7 +211,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_num_attention_heads = list(reversed(num_attention_heads))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
......@@ -230,7 +240,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
)
self.up_blocks.append(up_block)
......@@ -288,8 +298,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
......
......@@ -79,7 +79,7 @@ class Encoder(nn.Module):
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
attention_head_dim=output_channel,
temb_channels=None,
)
self.down_blocks.append(down_block)
......@@ -91,7 +91,7 @@ class Encoder(nn.Module):
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
......@@ -184,7 +184,7 @@ class Decoder(nn.Module):
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attn_num_head_channels=None,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
)
......@@ -208,7 +208,7 @@ class Decoder(nn.Module):
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
......
......@@ -396,7 +396,7 @@ class FlaxUNetMidBlock2D(nn.Module):
Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet and Attention block group norm
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
......@@ -405,7 +405,7 @@ class FlaxUNetMidBlock2D(nn.Module):
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
attn_num_head_channels: int = 1
num_attention_heads: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
......@@ -427,7 +427,7 @@ class FlaxUNetMidBlock2D(nn.Module):
for _ in range(self.num_layers):
attn_block = FlaxAttentionBlock(
channels=self.in_channels,
num_head_channels=self.attn_num_head_channels,
num_head_channels=self.num_attention_heads,
num_groups=resnet_groups,
dtype=self.dtype,
)
......@@ -532,7 +532,7 @@ class FlaxEncoder(nn.Module):
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
num_attention_heads=None,
dtype=self.dtype,
)
......@@ -625,7 +625,7 @@ class FlaxDecoder(nn.Module):
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
num_attention_heads=None,
dtype=self.dtype,
)
......
......@@ -41,7 +41,7 @@ def get_down_block(
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
......@@ -82,7 +82,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -101,7 +101,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
......@@ -141,7 +141,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -196,6 +196,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
......@@ -267,6 +269,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
......@@ -293,6 +296,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.sample_size = sample_size
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
......@@ -312,6 +323,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
f" {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
......@@ -457,6 +474,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
......@@ -492,7 +512,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -502,6 +522,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.down_blocks.append(down_block)
......@@ -515,7 +536,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -529,7 +550,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
......@@ -546,7 +567,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
......@@ -578,7 +599,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -587,6 +608,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
......@@ -679,8 +701,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
......@@ -1192,7 +1214,7 @@ class CrossAttnDownBlockFlat(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
......@@ -1207,7 +1229,7 @@ class CrossAttnDownBlockFlat(nn.Module):
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
......@@ -1228,8 +1250,8 @@ class CrossAttnDownBlockFlat(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1242,8 +1264,8 @@ class CrossAttnDownBlockFlat(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1426,7 +1448,7 @@ class CrossAttnUpBlockFlat(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
......@@ -1440,7 +1462,7 @@ class CrossAttnUpBlockFlat(nn.Module):
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
......@@ -1463,8 +1485,8 @@ class CrossAttnUpBlockFlat(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1477,8 +1499,8 @@ class CrossAttnUpBlockFlat(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1572,7 +1594,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
......@@ -1582,7 +1604,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
......@@ -1606,8 +1628,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1619,8 +1641,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1682,7 +1704,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
......@@ -1693,10 +1715,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.attention_head_dim = attention_head_dim
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.num_heads = in_channels // self.attn_num_head_channels
self.num_heads = in_channels // self.attention_head_dim
# there is always at least one resnet
resnets = [
......@@ -1726,7 +1748,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
query_dim=in_channels,
cross_attention_dim=in_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
dim_head=self.attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
......
......@@ -59,7 +59,7 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"block_out_channels": (32, 64),
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
"up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
"attention_head_dim": None,
"attention_head_dim": 3,
"out_channels": 3,
"in_channels": 3,
"layers_per_block": 2,
......
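
After this change, the corrected argument can be passed explicitly. A usage sketch, assuming a diffusers version that includes this commit and leaving all other constructor arguments at their defaults:

```python
from diffusers import UNet2DConditionModel

# New, correctly named argument: one head count per down block.
unet = UNet2DConditionModel(num_attention_heads=(5, 10, 20, 20))

# Legacy-style call: only `attention_head_dim` is given, so it is interpreted
# as the number of heads, keeping 40,000+ existing configs loadable unchanged.
legacy_unet = UNet2DConditionModel(attention_head_dim=8)
```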