Unverified Commit 88d26946, authored by Patrick von Platen, committed by GitHub

Correct bad attn naming (#3797)



* relax tolerance slightly

* correct incorrect naming

* correct naming

* correct more

* Apply suggestions from code review

* Fix more

* Correct more

* correct incorrect naming

* Update src/diffusers/models/controlnet.py

* Correct flax

* Correct renaming

* Correct blocks

* Fix more

* Correct more

* make style

* make style

* make style

* make style

* make style

* Fix flax

* make style

* rename

* rename

* rename attn head dim to attention_head_dim

* correct flax

* make style

* improve

* Correct more

* make style

* fix more

* make style

* Update src/diffusers/models/controlnet_flax.py

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 0c6d1bc9
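The substance of the patch is a small backward-compatibility shim, repeated across the model classes below: configurations that historically stored the number of attention heads under the misnamed `attention_head_dim` key keep working, while the correctly named `num_attention_heads` takes precedence when given. A minimal sketch of that fallback, assuming a standalone helper and example values that are illustrative rather than code from this commit:

from typing import Optional, Tuple, Union

def resolve_num_attention_heads(
    attention_head_dim: Union[int, Tuple[int, ...]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
    num_down_blocks: int = 4,
) -> Tuple[int, ...]:
    # The shim from the diff: when `num_attention_heads` is unset, fall back
    # to the historically misnamed `attention_head_dim` config value.
    num_attention_heads = num_attention_heads or attention_head_dim
    # Broadcast a single int to one entry per down block, as the models do.
    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * num_down_blocks
    return tuple(num_attention_heads)

# A legacy Stable Diffusion config stores the head *count* under the old key:
assert resolve_num_attention_heads(attention_head_dim=8) == (8, 8, 8, 8)
# A new config can use the correct name, which takes precedence:
assert resolve_num_attention_heads(8, num_attention_heads=(5, 10, 20, 20)) == (5, 10, 20, 20)

This is also why the Flax blocks below compute `d_head = out_channels // num_attention_heads`: with 320 channels and a head count of 8, the per-head dimension is 320 // 8 = 40, so the old value named `attention_head_dim` really was a head count all along.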
@@ -112,6 +112,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
@@ -124,6 +125,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
    ):
        super().__init__()
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
@@ -135,9 +144,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )
-        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
-                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )
        # input
@@ -198,6 +207,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
        # down
        output_channel = block_out_channels[0]
@@ -221,7 +233,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=num_attention_heads[i],
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
@@ -255,7 +268,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim[-1],
+            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
@@ -292,6 +305,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
+            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
@@ -390,8 +404,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
-                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
-                must be a multiple of `slice_size`.
+                provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
+                `num_attention_heads` must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []
...
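The docstring fix in the hunk above is easy to gloss over: sliced attention divides the *head count*, not the per-head dimension, so the divisor had to be renamed along with everything else. A toy illustration of the arithmetic the corrected docstring describes (the numbers are hypothetical; this is not library code):

# Hypothetical numbers illustrating the corrected slicing docstring.
num_attention_heads = 8

for slice_size in ("auto", "max", 4):
    if slice_size == "auto":
        heads_per_slice = num_attention_heads // 2  # halve the heads: two steps
    elif slice_size == "max":
        heads_per_slice = 1  # one head at a time: maximum memory savings
    else:
        assert num_attention_heads % slice_size == 0, "must divide the head count"
        heads_per_slice = slice_size
    num_slices = num_attention_heads // heads_per_slice
    print(f"slice_size={slice_size!r} -> {num_slices} slices of {heads_per_slice} head(s)")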
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
@@ -129,6 +129,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
+        num_attention_heads (`int` or `Tuple[int]`, *optional*):
+            The number of attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
@@ -155,6 +157,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
@@ -182,6 +185,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = self.num_attention_heads or self.attention_head_dim
        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
@@ -206,9 +217,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
-        attention_head_dim = self.attention_head_dim
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
        # down
        down_blocks = []
@@ -237,7 +247,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
-                    attn_num_head_channels=attention_head_dim[i],
+                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
@@ -285,7 +295,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            dropout=self.dropout,
-            attn_num_head_channels=attention_head_dim[-1],
+            num_attention_heads=num_attention_heads[-1],
            use_linear_projection=self.use_linear_projection,
            dtype=self.dtype,
        )
...
@@ -164,7 +164,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
-                attn_num_head_channels=attention_head_dim,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
@@ -178,7 +178,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
-            attn_num_head_channels=attention_head_dim,
+            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=add_attention,
        )
@@ -204,7 +204,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
-                attn_num_head_channels=attention_head_dim,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.up_blocks.append(up_block)
...
... (one file's diff is collapsed and not shown)
@@ -33,7 +33,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
@@ -46,7 +46,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
@@ -70,8 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.out_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.out_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.out_channels // self.num_attention_heads,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
@@ -172,7 +172,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsampling layer before each final output
@@ -186,7 +186,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
@@ -211,8 +211,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.out_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.out_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.out_channels // self.num_attention_heads,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
@@ -317,7 +317,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
@@ -327,7 +327,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
@@ -348,8 +348,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.in_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.in_channels // self.num_attention_heads,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
...
@@ -103,6 +103,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to None):
@@ -169,6 +171,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
@@ -195,6 +198,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        self.sample_size = sample_size
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
@@ -211,6 +222,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
@@ -353,6 +369,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)
@@ -388,7 +407,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
@@ -398,6 +417,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.down_blocks.append(down_block)
@@ -411,7 +431,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
-                attn_num_head_channels=attention_head_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
@@ -425,7 +445,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
-                attn_num_head_channels=attention_head_dim[-1],
+                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
@@ -442,7 +462,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        only_cross_attention = list(reversed(only_cross_attention))
@@ -474,7 +494,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
-                attn_num_head_channels=reversed_attention_head_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
@@ -483,6 +503,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
@@ -575,8 +596,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
-                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
-                must be a multiple of `slice_size`.
+                provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
+                `num_attention_heads` must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
@@ -81,6 +81,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
+        num_attention_heads (`int` or `Tuple[int]`, *optional*):
+            The number of attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
@@ -131,6 +134,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = self.num_attention_heads or self.attention_head_dim
        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
@@ -150,9 +161,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
-        attention_head_dim = self.attention_head_dim
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
        # down
        down_blocks = []
@@ -168,7 +178,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
-                    attn_num_head_channels=attention_head_dim[i],
+                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
@@ -192,7 +202,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            dropout=self.dropout,
-            attn_num_head_channels=attention_head_dim[-1],
+            num_attention_heads=num_attention_heads[-1],
            use_linear_projection=self.use_linear_projection,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
@@ -201,7 +211,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        # up
        up_blocks = []
        reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
@@ -217,7 +227,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
-                    attn_num_head_channels=reversed_attention_head_dim[i],
+                    num_attention_heads=reversed_num_attention_heads[i],
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    use_linear_projection=self.use_linear_projection,
...
@@ -29,7 +29,7 @@ def get_down_block(
    add_downsample,
    resnet_eps,
    resnet_act_fn,
-    attn_num_head_channels,
+    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
@@ -66,7 +66,7 @@ def get_down_block(
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
@@ -86,7 +86,7 @@ def get_up_block(
    add_upsample,
    resnet_eps,
    resnet_act_fn,
-    attn_num_head_channels,
+    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
@@ -122,7 +122,7 @@ def get_up_block(
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
@@ -144,7 +144,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
@@ -154,7 +154,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
        super().__init__()
        self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        # there is always at least one resnet
@@ -185,8 +185,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
        for _ in range(num_layers):
            attentions.append(
                Transformer2DModel(
-                    in_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    in_channels // num_attention_heads,
+                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
@@ -197,8 +197,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
            )
            temp_attentions.append(
                TransformerTemporalModel(
-                    in_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    in_channels // num_attention_heads,
+                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
@@ -273,7 +273,7 @@ class CrossAttnDownBlock3D(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
@@ -290,7 +290,7 @@ class CrossAttnDownBlock3D(nn.Module):
        temp_convs = []
        self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
@@ -317,8 +317,8 @@ class CrossAttnDownBlock3D(nn.Module):
            )
            attentions.append(
                Transformer2DModel(
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
@@ -330,8 +330,8 @@ class CrossAttnDownBlock3D(nn.Module):
            )
            temp_attentions.append(
                TransformerTemporalModel(
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
@@ -486,7 +486,7 @@ class CrossAttnUpBlock3D(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
@@ -502,7 +502,7 @@ class CrossAttnUpBlock3D(nn.Module):
        temp_attentions = []
        self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -531,8 +531,8 @@ class CrossAttnUpBlock3D(nn.Module):
            )
            attentions.append(
                Transformer2DModel(
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
@@ -544,8 +544,8 @@ class CrossAttnUpBlock3D(nn.Module):
            )
            temp_attentions.append(
                TransformerTemporalModel(
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
...
@@ -79,6 +79,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*): The number of attention heads.
    """
    _supports_gradient_checkpointing = False
@@ -105,11 +106,20 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1024,
        attention_head_dim: Union[int, Tuple[int]] = 64,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    ):
        super().__init__()
        self.sample_size = sample_size
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
@@ -121,9 +131,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )
-        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
-                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )
        # input
@@ -156,8 +166,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
        # down
        output_channel = block_out_channels[0]
@@ -177,7 +187,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=False,
            )
@@ -191,7 +201,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim[-1],
+            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=False,
        )
@@ -201,7 +211,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
@@ -230,7 +240,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=reversed_attention_head_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=False,
            )
            self.up_blocks.append(up_block)
@@ -288,8 +298,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
-                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
-                must be a multiple of `slice_size`.
+                provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
+                `num_attention_heads` must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []
...
...@@ -79,7 +79,7 @@ class Encoder(nn.Module): ...@@ -79,7 +79,7 @@ class Encoder(nn.Module):
downsample_padding=0, downsample_padding=0,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
attn_num_head_channels=None, attention_head_dim=output_channel,
temb_channels=None, temb_channels=None,
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
...@@ -91,7 +91,7 @@ class Encoder(nn.Module): ...@@ -91,7 +91,7 @@ class Encoder(nn.Module):
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
output_scale_factor=1, output_scale_factor=1,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
attn_num_head_channels=None, attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
temb_channels=None, temb_channels=None,
) )
...@@ -184,7 +184,7 @@ class Decoder(nn.Module): ...@@ -184,7 +184,7 @@ class Decoder(nn.Module):
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
output_scale_factor=1, output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type, resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attn_num_head_channels=None, attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
temb_channels=temb_channels, temb_channels=temb_channels,
) )
...@@ -208,7 +208,7 @@ class Decoder(nn.Module): ...@@ -208,7 +208,7 @@ class Decoder(nn.Module):
resnet_eps=1e-6, resnet_eps=1e-6,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
attn_num_head_channels=None, attention_head_dim=output_channel,
temb_channels=temb_channels, temb_channels=temb_channels,
resnet_time_scale_shift=norm_type, resnet_time_scale_shift=norm_type,
) )
......
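In the VAE blocks above, the old `attn_num_head_channels=None` fell back to a single attention head; passing `attention_head_dim=output_channel` (or `block_out_channels[-1]`) preserves that behavior, because the head count is derived as `channels // attention_head_dim`. A sketch with an illustrative channel count:

channels = 512
attention_head_dim = channels          # what the diff now passes explicitly
num_heads = channels // attention_head_dim
assert num_heads == 1                  # same single-head attention as before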
...@@ -396,7 +396,7 @@ class FlaxUNetMidBlock2D(nn.Module): ...@@ -396,7 +396,7 @@ class FlaxUNetMidBlock2D(nn.Module):
Number of Resnet layer block Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`): resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet and Attention block group norm The number of groups to use for the Resnet and Attention block group norm
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`): num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype` Parameters `dtype`
...@@ -405,7 +405,7 @@ class FlaxUNetMidBlock2D(nn.Module): ...@@ -405,7 +405,7 @@ class FlaxUNetMidBlock2D(nn.Module):
dropout: float = 0.0 dropout: float = 0.0
num_layers: int = 1 num_layers: int = 1
resnet_groups: int = 32 resnet_groups: int = 32
attn_num_head_channels: int = 1 num_attention_heads: int = 1
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
def setup(self): def setup(self):
...@@ -427,7 +427,7 @@ class FlaxUNetMidBlock2D(nn.Module): ...@@ -427,7 +427,7 @@ class FlaxUNetMidBlock2D(nn.Module):
for _ in range(self.num_layers): for _ in range(self.num_layers):
attn_block = FlaxAttentionBlock( attn_block = FlaxAttentionBlock(
channels=self.in_channels, channels=self.in_channels,
num_head_channels=self.attn_num_head_channels, num_head_channels=self.num_attention_heads,
num_groups=resnet_groups, num_groups=resnet_groups,
dtype=self.dtype, dtype=self.dtype,
) )
...@@ -532,7 +532,7 @@ class FlaxEncoder(nn.Module): ...@@ -532,7 +532,7 @@ class FlaxEncoder(nn.Module):
self.mid_block = FlaxUNetMidBlock2D( self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups, resnet_groups=self.norm_num_groups,
attn_num_head_channels=None, num_attention_heads=None,
dtype=self.dtype, dtype=self.dtype,
) )
...@@ -625,7 +625,7 @@ class FlaxDecoder(nn.Module): ...@@ -625,7 +625,7 @@ class FlaxDecoder(nn.Module):
self.mid_block = FlaxUNetMidBlock2D( self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups, resnet_groups=self.norm_num_groups,
attn_num_head_channels=None, num_attention_heads=None,
dtype=self.dtype, dtype=self.dtype,
) )
......
...@@ -41,7 +41,7 @@ def get_down_block( ...@@ -41,7 +41,7 @@ def get_down_block(
add_downsample, add_downsample,
resnet_eps, resnet_eps,
resnet_act_fn, resnet_act_fn,
attn_num_head_channels, num_attention_heads,
resnet_groups=None, resnet_groups=None,
cross_attention_dim=None, cross_attention_dim=None,
downsample_padding=None, downsample_padding=None,
...@@ -82,7 +82,7 @@ def get_down_block( ...@@ -82,7 +82,7 @@ def get_down_block(
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention, only_cross_attention=only_cross_attention,
...@@ -101,7 +101,7 @@ def get_up_block( ...@@ -101,7 +101,7 @@ def get_up_block(
add_upsample, add_upsample,
resnet_eps, resnet_eps,
resnet_act_fn, resnet_act_fn,
attn_num_head_channels, num_attention_heads,
resnet_groups=None, resnet_groups=None,
cross_attention_dim=None, cross_attention_dim=None,
dual_cross_attention=False, dual_cross_attention=False,
...@@ -141,7 +141,7 @@ def get_up_block( ...@@ -141,7 +141,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention, only_cross_attention=only_cross_attention,
...@@ -196,6 +196,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -196,6 +196,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`.

resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): class_embed_type (`str`, *optional*, defaults to None):
...@@ -267,6 +269,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -267,6 +269,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
encoder_hid_dim: Optional[int] = None, encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None, encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8, attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
dual_cross_attention: bool = False, dual_cross_attention: bool = False,
use_linear_projection: bool = False, use_linear_projection: bool = False,
class_embed_type: Optional[str] = None, class_embed_type: Optional[str] = None,
...@@ -293,6 +296,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -293,6 +296,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.sample_size = sample_size self.sample_size = sample_size
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
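# A sketch of how the fallback above behaves for a typical legacy config;
# the tuple values below are made up for illustration:
legacy_attention_head_dim = (5, 10, 20, 20)  # only this key was ever serialized
legacy_num_attention_heads = None            # absent from most existing configs
resolved = legacy_num_attention_heads or legacy_attention_head_dim
assert resolved == (5, 10, 20, 20)  # the misnamed value is reinterpreted as the head count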
# Check inputs # Check inputs
if len(down_block_types) != len(up_block_types): if len(down_block_types) != len(up_block_types):
raise ValueError( raise ValueError(
...@@ -312,6 +323,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -312,6 +323,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
) )
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
f" {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError( raise ValueError(
"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:" "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
...@@ -457,6 +474,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -457,6 +474,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if mid_block_only_cross_attention is None: if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False mid_block_only_cross_attention = False
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int): if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types) attention_head_dim = (attention_head_dim,) * len(down_block_types)
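# Sketch of the int-to-tuple broadcast above; the block types are illustrative:
example_down_block_types = ("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")
example_num_attention_heads = 8
if isinstance(example_num_attention_heads, int):
    example_num_attention_heads = (example_num_attention_heads,) * len(example_down_block_types)
assert example_num_attention_heads == (8, 8, 8)  # later indexed per block as [i]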
...@@ -492,7 +512,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -492,7 +512,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i], cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i], num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
...@@ -502,6 +522,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -502,6 +522,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_skip_time_act=resnet_skip_time_act, resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor, resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm, cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
...@@ -515,7 +536,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -515,7 +536,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
output_scale_factor=mid_block_scale_factor, output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1], cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1], num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
...@@ -529,7 +550,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -529,7 +550,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor, output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1], cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1], attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act, skip_time_act=resnet_skip_time_act,
...@@ -546,7 +567,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -546,7 +567,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# up # up
reversed_block_out_channels = list(reversed(block_out_channels)) reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim)) reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block)) reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention)) only_cross_attention = list(reversed(only_cross_attention))
...@@ -578,7 +599,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -578,7 +599,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i], cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i], num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
...@@ -587,6 +608,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -587,6 +608,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_skip_time_act=resnet_skip_time_act, resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor, resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm, cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
...@@ -679,8 +701,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -679,8 +701,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
must be a multiple of `slice_size`. `num_attention_heads` must be a multiple of `slice_size`.
""" """
sliceable_head_dims = [] sliceable_head_dims = []
...@@ -1192,7 +1214,7 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -1192,7 +1214,7 @@ class CrossAttnDownBlockFlat(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attn_num_head_channels=1, num_attention_heads=1,
cross_attention_dim=1280, cross_attention_dim=1280,
output_scale_factor=1.0, output_scale_factor=1.0,
downsample_padding=1, downsample_padding=1,
...@@ -1207,7 +1229,7 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -1207,7 +1229,7 @@ class CrossAttnDownBlockFlat(nn.Module):
attentions = [] attentions = []
self.has_cross_attention = True self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels self.num_attention_heads = num_attention_heads
for i in range(num_layers): for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels in_channels = in_channels if i == 0 else out_channels
...@@ -1228,8 +1250,8 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -1228,8 +1250,8 @@ class CrossAttnDownBlockFlat(nn.Module):
if not dual_cross_attention: if not dual_cross_attention:
attentions.append( attentions.append(
Transformer2DModel( Transformer2DModel(
attn_num_head_channels, num_attention_heads,
out_channels // attn_num_head_channels, out_channels // num_attention_heads,
in_channels=out_channels, in_channels=out_channels,
num_layers=1, num_layers=1,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
...@@ -1242,8 +1264,8 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -1242,8 +1264,8 @@ class CrossAttnDownBlockFlat(nn.Module):
else: else:
attentions.append( attentions.append(
DualTransformer2DModel( DualTransformer2DModel(
attn_num_head_channels, num_attention_heads,
out_channels // attn_num_head_channels, out_channels // num_attention_heads,
in_channels=out_channels, in_channels=out_channels,
num_layers=1, num_layers=1,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
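# Sketch of the invariant both branches above rely on: the head count times the
# per-head dim recovers the block width, so the transformer's inner dim is unchanged.
# Numbers are illustrative:
example_out_channels = 320
example_num_attention_heads = 8
example_dim_head = example_out_channels // example_num_attention_heads  # 40
assert example_num_attention_heads * example_dim_head == example_out_channels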
...@@ -1426,7 +1448,7 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -1426,7 +1448,7 @@ class CrossAttnUpBlockFlat(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attn_num_head_channels=1, num_attention_heads=1,
cross_attention_dim=1280, cross_attention_dim=1280,
output_scale_factor=1.0, output_scale_factor=1.0,
add_upsample=True, add_upsample=True,
...@@ -1440,7 +1462,7 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -1440,7 +1462,7 @@ class CrossAttnUpBlockFlat(nn.Module):
attentions = [] attentions = []
self.has_cross_attention = True self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels self.num_attention_heads = num_attention_heads
for i in range(num_layers): for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
...@@ -1463,8 +1485,8 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -1463,8 +1485,8 @@ class CrossAttnUpBlockFlat(nn.Module):
if not dual_cross_attention: if not dual_cross_attention:
attentions.append( attentions.append(
Transformer2DModel( Transformer2DModel(
attn_num_head_channels, num_attention_heads,
out_channels // attn_num_head_channels, out_channels // num_attention_heads,
in_channels=out_channels, in_channels=out_channels,
num_layers=1, num_layers=1,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
...@@ -1477,8 +1499,8 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -1477,8 +1499,8 @@ class CrossAttnUpBlockFlat(nn.Module):
else: else:
attentions.append( attentions.append(
DualTransformer2DModel( DualTransformer2DModel(
attn_num_head_channels, num_attention_heads,
out_channels // attn_num_head_channels, out_channels // num_attention_heads,
in_channels=out_channels, in_channels=out_channels,
num_layers=1, num_layers=1,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
...@@ -1572,7 +1594,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module): ...@@ -1572,7 +1594,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attn_num_head_channels=1, num_attention_heads=1,
output_scale_factor=1.0, output_scale_factor=1.0,
cross_attention_dim=1280, cross_attention_dim=1280,
dual_cross_attention=False, dual_cross_attention=False,
...@@ -1582,7 +1604,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module): ...@@ -1582,7 +1604,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
super().__init__() super().__init__()
self.has_cross_attention = True self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet # there is always at least one resnet
...@@ -1606,8 +1628,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module): ...@@ -1606,8 +1628,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
if not dual_cross_attention: if not dual_cross_attention:
attentions.append( attentions.append(
Transformer2DModel( Transformer2DModel(
attn_num_head_channels, num_attention_heads,
in_channels // attn_num_head_channels, in_channels // num_attention_heads,
in_channels=in_channels, in_channels=in_channels,
num_layers=1, num_layers=1,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
...@@ -1619,8 +1641,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module): ...@@ -1619,8 +1641,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
else: else:
attentions.append( attentions.append(
DualTransformer2DModel( DualTransformer2DModel(
attn_num_head_channels, num_attention_heads,
in_channels // attn_num_head_channels, in_channels // num_attention_heads,
in_channels=in_channels, in_channels=in_channels,
num_layers=1, num_layers=1,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
...@@ -1682,7 +1704,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1682,7 +1704,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attn_num_head_channels=1, attention_head_dim=1,
output_scale_factor=1.0, output_scale_factor=1.0,
cross_attention_dim=1280, cross_attention_dim=1280,
skip_time_act=False, skip_time_act=False,
...@@ -1693,10 +1715,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1693,10 +1715,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
self.has_cross_attention = True self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels self.attention_head_dim = attention_head_dim
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.num_heads = in_channels // self.attn_num_head_channels self.num_heads = in_channels // self.attention_head_dim
# there is always at least one resnet # there is always at least one resnet
resnets = [ resnets = [
...@@ -1726,7 +1748,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1726,7 +1748,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
query_dim=in_channels, query_dim=in_channels,
cross_attention_dim=in_channels, cross_attention_dim=in_channels,
heads=self.num_heads, heads=self.num_heads,
dim_head=attn_num_head_channels, dim_head=self.attention_head_dim,
added_kv_proj_dim=cross_attention_dim, added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups, norm_num_groups=resnet_groups,
bias=True, bias=True,
......
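Note that `UNetMidBlockFlatSimpleCrossAttn` is the one block above where the parameter genuinely is a head dim: it stores `attention_head_dim` and derives the head count, while the cross-attention blocks store `num_attention_heads` and derive the per-head dim. A sketch contrasting the two conventions, with illustrative numbers:

in_channels = 1280
# cross-attention blocks: configured by head count, per-head dim derived
num_attention_heads = 8
dim_head = in_channels // num_attention_heads   # 160
# simple mid block: configured by head dim, head count derived
attention_head_dim = 64
num_heads = in_channels // attention_head_dim   # 20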
...@@ -59,7 +59,7 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -59,7 +59,7 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"block_out_channels": (32, 64), "block_out_channels": (32, 64),
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"), "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
"up_block_types": ("AttnUpBlock2D", "UpBlock2D"), "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
"attention_head_dim": None, "attention_head_dim": 3,
"out_channels": 3, "out_channels": 3,
"in_channels": 3, "in_channels": 3,
"layers_per_block": 2, "layers_per_block": 2,
......
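For reference, the updated test dict corresponds roughly to this construction (a sketch assuming the public `UNet2DModel` constructor; not part of the diff itself):

from diffusers import UNet2DModel

# mirrors the common config above; attention_head_dim is now explicit rather than None
model = UNet2DModel(
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    attention_head_dim=3,
)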