Unverified commit 88d26946 authored by Patrick von Platen, committed by GitHub

Correct bad attn naming (#3797)



* relax tolerance slightly

* correct incorrect naming

* correct naming

* correct more

* Apply suggestions from code review

* Fix more

* Correct more

* correct incorrect naming

* Update src/diffusers/models/controlnet.py

* Correct flax

* Correct renaming

* Correct blocks

* Fix more

* Correct more

* make style

* make style

* make style

* make style

* make style

* Fix flax

* make style

* rename

* rename

* rename attn head dim to attention_head_dim

* correct flax

* make style

* improve

* Correct more

* make style

* fix more

* make style

* Update src/diffusers/models/controlnet_flax.py

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 0c6d1bc9
......@@ -112,6 +112,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
......@@ -124,6 +125,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
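# Illustrative aside, not part of this diff: a Stable Diffusion 2.x-style config stores
# `attention_head_dim=(5, 10, 20, 20)` and no `num_attention_heads`, so the line above
# yields `num_attention_heads == (5, 10, 20, 20)`. Those stored values are really head
# counts, and the true per-head dimension is recovered later as
# `block_out_channels[i] // num_attention_heads[i]` (e.g. 320 // 5 == 64).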
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
......@@ -135,9 +144,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
......@@ -198,6 +207,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
......@@ -221,7 +233,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -255,7 +268,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
......@@ -292,6 +305,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
......@@ -390,8 +404,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
......
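As a minimal usage sketch (not part of this diff), the keys forwarded by `from_unet` above keep old checkpoints working; the repository id below is an assumed, standard Stable Diffusion 1.5 layout:

from diffusers import ControlNetModel, UNet2DConditionModel

# The loaded UNet config typically only stores the legacy `attention_head_dim` key;
# `from_unet` now forwards both it and `num_attention_heads` (None for old configs),
# and the shim above maps None back onto `attention_head_dim`.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
controlnet = ControlNetModel.from_unet(unet)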
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
......@@ -129,6 +129,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
The number of layers per block.
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
num_attention_heads (`int` or `Tuple[int]`, *optional*):
The number of attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
......@@ -155,6 +157,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
......@@ -182,6 +185,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = self.num_attention_heads or self.attention_head_dim
# input
self.conv_in = nn.Conv(
block_out_channels[0],
......@@ -206,9 +217,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
attention_head_dim = self.attention_head_dim
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
# down
down_blocks = []
......@@ -237,7 +247,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -285,7 +295,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
dropout=self.dropout,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)
......
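A small sketch of the per-block broadcast performed in the setup above; the block types and head count are assumed values for illustration:

down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
num_attention_heads = 8  # a single int, as in most configs

if isinstance(num_attention_heads, int):
    num_attention_heads = (num_attention_heads,) * len(down_block_types)

print(num_attention_heads)  # (8, 8, 8, 8): each down block indexes its own entry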
......@@ -164,7 +164,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
......@@ -178,7 +178,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attn_num_head_channels=attention_head_dim,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
add_attention=add_attention,
)
......@@ -204,7 +204,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.up_blocks.append(up_block)
......
......@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import is_torch_version
from ..utils import is_torch_version, logging
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel
......@@ -26,6 +26,9 @@ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D,
from .transformer_2d import Transformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_down_block(
down_block_type,
num_layers,
......@@ -35,7 +38,7 @@ def get_down_block(
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads=None,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
......@@ -47,7 +50,15 @@ def get_down_block(
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
attention_head_dim=None,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
return DownBlock2D(
......@@ -87,7 +98,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "CrossAttnDownBlock2D":
......@@ -104,7 +115,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -124,7 +135,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
......@@ -152,8 +163,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "DownEncoderBlock2D":
......@@ -178,7 +188,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "KDownBlock2D":
......@@ -201,7 +211,7 @@ def get_down_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
add_self_attention=True if not add_downsample else False,
)
raise ValueError(f"{down_block_type} does not exist.")
......@@ -217,7 +227,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads=None,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
......@@ -228,7 +238,15 @@ def get_up_block(
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
attention_head_dim=None,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
return UpBlock2D(
......@@ -272,7 +290,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -293,7 +311,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
......@@ -311,7 +329,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "SkipUpBlock2D":
......@@ -336,7 +354,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "UpDecoderBlock2D":
......@@ -360,7 +378,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
)
......@@ -384,7 +402,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
attention_head_dim=attention_head_dim,
)
raise ValueError(f"{up_block_type} does not exist.")
......@@ -403,7 +421,7 @@ class UNetMidBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=1.0,
):
super().__init__()
......@@ -427,13 +445,19 @@ class UNetMidBlock2D(nn.Module):
]
attentions = []
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
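# Worked example (illustrative only): for a mid block with in_channels=512 and
# attention_head_dim=None, the fallback above gives attention_head_dim=512, so the
# Attention layer below is built with heads = 512 // 512 = 1 and dim_head = 512,
# matching the single-head behavior of the old `attn_num_head_channels=None` branch.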
for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
......@@ -487,7 +511,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
......@@ -497,7 +521,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
......@@ -521,8 +545,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -534,8 +558,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -596,7 +620,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
......@@ -607,10 +631,10 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.attention_head_dim = attention_head_dim
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.num_heads = in_channels // self.attn_num_head_channels
self.num_heads = in_channels // self.attention_head_dim
# there is always at least one resnet
resnets = [
......@@ -640,7 +664,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
query_dim=in_channels,
cross_attention_dim=in_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
dim_head=self.attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
......@@ -720,7 +744,7 @@ class AttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
......@@ -729,6 +753,12 @@ class AttnDownBlock2D(nn.Module):
resnets = []
attentions = []
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
......@@ -748,8 +778,8 @@ class AttnDownBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
......@@ -804,7 +834,7 @@ class CrossAttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
......@@ -819,7 +849,7 @@ class CrossAttnDownBlock2D(nn.Module):
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
......@@ -840,8 +870,8 @@ class CrossAttnDownBlock2D(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -854,8 +884,8 @@ class CrossAttnDownBlock2D(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1099,7 +1129,7 @@ class AttnDownEncoderBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
......@@ -1108,6 +1138,12 @@ class AttnDownEncoderBlock2D(nn.Module):
resnets = []
attentions = []
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
......@@ -1127,8 +1163,8 @@ class AttnDownEncoderBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
......@@ -1177,15 +1213,20 @@ class AttnSkipDownBlock2D(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=np.sqrt(2.0),
downsample_padding=1,
add_downsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
......@@ -1206,8 +1247,8 @@ class AttnSkipDownBlock2D(nn.Module):
self.attentions.append(
Attention(
out_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=32,
......@@ -1451,7 +1492,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_downsample=True,
......@@ -1466,8 +1507,8 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
resnets = []
attentions = []
self.attn_num_head_channels = attn_num_head_channels
self.num_heads = out_channels // self.attn_num_head_channels
self.attention_head_dim = attention_head_dim
self.num_heads = out_channels // self.attention_head_dim
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
......@@ -1496,7 +1537,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
dim_head=attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
......@@ -1686,7 +1727,7 @@ class KCrossAttnDownBlock2D(nn.Module):
num_layers: int = 4,
resnet_group_size: int = 32,
add_downsample=True,
attn_num_head_channels: int = 64,
attention_head_dim: int = 64,
add_self_attention: bool = False,
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
......@@ -1719,8 +1760,8 @@ class KCrossAttnDownBlock2D(nn.Module):
attentions.append(
KAttentionBlock(
out_channels,
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // attention_head_dim,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,
......@@ -1817,7 +1858,7 @@ class AttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=1.0,
add_upsample=True,
):
......@@ -1825,6 +1866,12 @@ class AttnUpBlock2D(nn.Module):
resnets = []
attentions = []
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
......@@ -1846,8 +1893,8 @@ class AttnUpBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
......@@ -1897,7 +1944,7 @@ class CrossAttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
......@@ -1911,7 +1958,7 @@ class CrossAttnUpBlock2D(nn.Module):
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
......@@ -1934,8 +1981,8 @@ class CrossAttnUpBlock2D(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1948,8 +1995,8 @@ class CrossAttnUpBlock2D(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -2178,7 +2225,7 @@ class AttnUpDecoderBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=1.0,
add_upsample=True,
temb_channels=None,
......@@ -2187,6 +2234,12 @@ class AttnUpDecoderBlock2D(nn.Module):
resnets = []
attentions = []
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
)
attention_head_dim = out_channels
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
......@@ -2207,8 +2260,8 @@ class AttnUpDecoderBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
......@@ -2253,9 +2306,8 @@ class AttnSkipUpBlock2D(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=np.sqrt(2.0),
upsample_padding=1,
add_upsample=True,
):
super().__init__()
......@@ -2282,11 +2334,17 @@ class AttnSkipUpBlock2D(nn.Module):
)
)
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
)
attention_head_dim = out_channels
self.attentions.append(
Attention(
out_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=32,
......@@ -2563,7 +2621,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
......@@ -2576,9 +2634,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.attention_head_dim = attention_head_dim
self.num_heads = out_channels // self.attn_num_head_channels
self.num_heads = out_channels // self.attention_head_dim
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
......@@ -2609,7 +2667,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
dim_head=self.attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
......@@ -2804,7 +2862,7 @@ class KCrossAttnUpBlock2D(nn.Module):
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: int = 32,
attn_num_head_channels=1, # attention dim_head
attention_head_dim=1, # attention dim_head
cross_attention_dim: int = 768,
add_upsample: bool = True,
upcast_attention: bool = False,
......@@ -2818,7 +2876,7 @@ class KCrossAttnUpBlock2D(nn.Module):
add_self_attention = True if is_first_block else False
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.attention_head_dim = attention_head_dim
# in_channels, and out_channels for the block (k-unet)
k_in_channels = out_channels if is_first_block else 2 * out_channels
......@@ -2854,10 +2912,10 @@ class KCrossAttnUpBlock2D(nn.Module):
attentions.append(
KAttentionBlock(
k_out_channels if (i == num_layers - 1) else out_channels,
k_out_channels // attn_num_head_channels
k_out_channels // attention_head_dim
if (i == num_layers - 1)
else out_channels // attn_num_head_channels,
attn_num_head_channels,
else out_channels // attention_head_dim,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,
......
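For the cross-attention blocks above, the two leading `Transformer2DModel` arguments keep their old values and only their names change; a short arithmetic sketch with assumed SD-style numbers:

out_channels = 640
num_attention_heads = 10

attention_head_dim = out_channels // num_attention_heads
print(num_attention_heads, attention_head_dim)  # 10 64

# Transformer2DModel is still constructed as (heads, dim per head, ...), i.e. (10, 64, ...);
# the rename only makes the argument names say what the numbers mean.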
......@@ -33,7 +33,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
......@@ -46,7 +46,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
out_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
num_attention_heads: int = 1
add_downsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
......@@ -70,8 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
......@@ -172,7 +172,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
......@@ -186,7 +186,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
prev_output_channel: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
num_attention_heads: int = 1
add_upsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
......@@ -211,8 +211,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
......@@ -317,7 +317,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
......@@ -327,7 +327,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
num_attention_heads: int = 1
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
......@@ -348,8 +348,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
for _ in range(self.num_layers):
attn_block = FlaxTransformer2DModel(
in_channels=self.in_channels,
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
n_heads=self.num_attention_heads,
d_head=self.in_channels // self.num_attention_heads,
depth=1,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
......
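After the rename, the Flax blocks are configured with the head count directly; a hedged construction sketch, with assumed channel and head values:

import jax.numpy as jnp

from diffusers.models.unet_2d_blocks_flax import FlaxCrossAttnDownBlock2D

# The module now exposes `num_attention_heads` in place of the old `attn_num_head_channels`;
# internally the per-head dimension is out_channels // num_attention_heads (320 // 8 == 40 here).
block = FlaxCrossAttnDownBlock2D(
    in_channels=320,
    out_channels=320,
    num_attention_heads=8,
    dtype=jnp.float32,
)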
......@@ -103,6 +103,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
......@@ -169,6 +171,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
......@@ -195,6 +198,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.sample_size = sample_size
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
......@@ -211,6 +222,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
......@@ -353,6 +369,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
......@@ -388,7 +407,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -398,6 +417,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.down_blocks.append(down_block)
......@@ -411,7 +431,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -425,7 +445,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
......@@ -442,7 +462,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
......@@ -474,7 +494,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -483,6 +503,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
......@@ -575,8 +596,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
......
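A small arithmetic sketch of the `slice_size` rule documented above; the head count and slice size are assumed values:

num_attention_heads = 8
slice_size = 2

if num_attention_heads % slice_size != 0:
    raise ValueError("`num_attention_heads` must be a multiple of `slice_size`.")

num_slices = num_attention_heads // slice_size
print(num_slices)  # 4: attention is computed in 4 sequential slices of 2 heads each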
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
......@@ -81,6 +81,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The number of layers per block.
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
num_attention_heads (`int` or `Tuple[int]`, *optional*):
The number of attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
......@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
......@@ -131,6 +134,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = self.num_attention_heads or self.attention_head_dim
# input
self.conv_in = nn.Conv(
block_out_channels[0],
......@@ -150,9 +161,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
attention_head_dim = self.attention_head_dim
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
# down
down_blocks = []
......@@ -168,7 +178,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -192,7 +202,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
......@@ -201,7 +211,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_num_attention_heads = list(reversed(num_attention_heads))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(self.up_block_types):
......@@ -217,7 +227,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
add_upsample=not is_final_block,
dropout=self.dropout,
use_linear_projection=self.use_linear_projection,
......
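The up path mirrors the down path, so the per-block head counts are simply reversed before being consumed; a two-line sketch with assumed values:

num_attention_heads = (5, 10, 20, 20)  # assumed per-down-block head counts
reversed_num_attention_heads = list(reversed(num_attention_heads))
print(reversed_num_attention_heads)  # [20, 20, 10, 5]: up block i reuses its mirrored down block's count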
......@@ -29,7 +29,7 @@ def get_down_block(
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
......@@ -66,7 +66,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -86,7 +86,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
......@@ -122,7 +122,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -144,7 +144,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
......@@ -154,7 +154,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
......@@ -185,8 +185,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
in_channels // attn_num_head_channels,
attn_num_head_channels,
in_channels // num_attention_heads,
num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -197,8 +197,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
)
temp_attentions.append(
TransformerTemporalModel(
in_channels // attn_num_head_channels,
attn_num_head_channels,
in_channels // num_attention_heads,
num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -273,7 +273,7 @@ class CrossAttnDownBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
......@@ -290,7 +290,7 @@ class CrossAttnDownBlock3D(nn.Module):
temp_convs = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
......@@ -317,8 +317,8 @@ class CrossAttnDownBlock3D(nn.Module):
)
attentions.append(
Transformer2DModel(
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // num_attention_heads,
num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -330,8 +330,8 @@ class CrossAttnDownBlock3D(nn.Module):
)
temp_attentions.append(
TransformerTemporalModel(
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // num_attention_heads,
num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -486,7 +486,7 @@ class CrossAttnUpBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
......@@ -502,7 +502,7 @@ class CrossAttnUpBlock3D(nn.Module):
temp_attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
......@@ -531,8 +531,8 @@ class CrossAttnUpBlock3D(nn.Module):
)
attentions.append(
Transformer2DModel(
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // num_attention_heads,
num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -544,8 +544,8 @@ class CrossAttnUpBlock3D(nn.Module):
)
temp_attentions.append(
TransformerTemporalModel(
out_channels // attn_num_head_channels,
attn_num_head_channels,
out_channels // num_attention_heads,
num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......
......@@ -79,6 +79,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*): The number of attention heads.
"""
_supports_gradient_checkpointing = False
......@@ -105,11 +106,20 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps: float = 1e-5,
cross_attention_dim: int = 1024,
attention_head_dim: Union[int, Tuple[int]] = 64,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
):
super().__init__()
self.sample_size = sample_size
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
......@@ -121,9 +131,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
......@@ -156,8 +166,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
......@@ -177,7 +187,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
dual_cross_attention=False,
)
......@@ -191,7 +201,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
)
......@@ -201,7 +211,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_num_attention_heads = list(reversed(num_attention_heads))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
......@@ -230,7 +240,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
)
self.up_blocks.append(up_block)
......@@ -288,8 +298,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
......
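The hunks above for `UNet3DConditionModel` apply the same fallback as the 2D models: when `num_attention_heads` is not set, the (misnamed) `attention_head_dim` value is reused, and a single int is broadcast to one entry per down block. A minimal, standalone sketch of that logic follows; the helper name and example values are hypothetical, not part of the diff or the library API.

```python
from typing import Optional, Tuple, Union

def resolve_num_attention_heads(
    attention_head_dim: Union[int, Tuple[int, ...]],
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]],
    down_block_types: Tuple[str, ...],
) -> Tuple[int, ...]:
    # Backward-compat fallback, as in the hunks above: old configs only set
    # `attention_head_dim`, so it doubles as the head count when the new
    # argument is absent.
    num_attention_heads = num_attention_heads or attention_head_dim

    # A single int is broadcast to one value per down block.
    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    if len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            "Must provide the same number of `num_attention_heads` as `down_block_types`."
        )
    return num_attention_heads

# e.g. a legacy config with only attention_head_dim=8 and four down blocks:
# resolve_num_attention_heads(8, None, ("CrossAttnDownBlock3D",) * 4) -> (8, 8, 8, 8)
```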
......@@ -79,7 +79,7 @@ class Encoder(nn.Module):
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
attention_head_dim=output_channel,
temb_channels=None,
)
self.down_blocks.append(down_block)
......@@ -91,7 +91,7 @@ class Encoder(nn.Module):
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
......@@ -184,7 +184,7 @@ class Decoder(nn.Module):
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attn_num_head_channels=None,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
)
......@@ -208,7 +208,7 @@ class Decoder(nn.Module):
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
......
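One reading of the `vae.py` hunks above: the Encoder/Decoder blocks previously passed `attn_num_head_channels=None`, and now pass `attention_head_dim` equal to the block's full channel count. Since the head count implied elsewhere in this diff is `channels // attention_head_dim`, this keeps the VAE attention single-headed. A tiny check with illustrative channel sizes (not values from the diff):

```python
block_out_channels = (128, 256, 512, 512)  # illustrative VAE channel config

attention_head_dim = block_out_channels[-1]              # full channel count, as in the mid-block hunk
implied_heads = block_out_channels[-1] // attention_head_dim
assert implied_heads == 1                                # same behavior as the old `None` default
```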
......@@ -396,7 +396,7 @@ class FlaxUNetMidBlock2D(nn.Module):
Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet and Attention block group norm
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
......@@ -405,7 +405,7 @@ class FlaxUNetMidBlock2D(nn.Module):
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
attn_num_head_channels: int = 1
num_attention_heads: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
......@@ -427,7 +427,7 @@ class FlaxUNetMidBlock2D(nn.Module):
for _ in range(self.num_layers):
attn_block = FlaxAttentionBlock(
channels=self.in_channels,
num_head_channels=self.attn_num_head_channels,
num_head_channels=self.num_attention_heads,
num_groups=resnet_groups,
dtype=self.dtype,
)
......@@ -532,7 +532,7 @@ class FlaxEncoder(nn.Module):
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
num_attention_heads=None,
dtype=self.dtype,
)
......@@ -625,7 +625,7 @@ class FlaxDecoder(nn.Module):
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
num_attention_heads=None,
dtype=self.dtype,
)
......
......@@ -41,7 +41,7 @@ def get_down_block(
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
......@@ -82,7 +82,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -101,7 +101,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
num_attention_heads,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
......@@ -141,7 +141,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
......@@ -196,6 +196,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
......@@ -267,6 +269,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
......@@ -293,6 +296,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.sample_size = sample_size
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
......@@ -312,6 +323,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
f" {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
......@@ -457,6 +474,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
......@@ -492,7 +512,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -502,6 +522,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.down_blocks.append(down_block)
......@@ -515,7 +536,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
......@@ -529,7 +550,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
......@@ -546,7 +567,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
......@@ -578,7 +599,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
......@@ -587,6 +608,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
......@@ -679,8 +701,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
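The updated `set_attention_slice` docstring above describes the slice count as `num_attention_heads // slice_size`, with `"auto"` and `"max"` as shortcuts. A sketch of that arithmetic as the docstring states it, not the library implementation; the function name and example numbers are hypothetical:

```python
from typing import Union

def slice_count(num_attention_heads: int, slice_size: Union[str, int]) -> int:
    # "auto" halves the input so attention runs in two steps; "max" runs one
    # slice at a time; an int must divide the head count evenly.
    if slice_size == "auto":
        return 2
    if slice_size == "max":
        return num_attention_heads
    if num_attention_heads % slice_size != 0:
        raise ValueError("`num_attention_heads` must be a multiple of `slice_size`.")
    return num_attention_heads // slice_size

# e.g. slice_count(8, 2) -> 4 slices
```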
......@@ -1192,7 +1214,7 @@ class CrossAttnDownBlockFlat(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
......@@ -1207,7 +1229,7 @@ class CrossAttnDownBlockFlat(nn.Module):
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
......@@ -1228,8 +1250,8 @@ class CrossAttnDownBlockFlat(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1242,8 +1264,8 @@ class CrossAttnDownBlockFlat(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1426,7 +1448,7 @@ class CrossAttnUpBlockFlat(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
......@@ -1440,7 +1462,7 @@ class CrossAttnUpBlockFlat(nn.Module):
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
......@@ -1463,8 +1485,8 @@ class CrossAttnUpBlockFlat(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1477,8 +1499,8 @@ class CrossAttnUpBlockFlat(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1572,7 +1594,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
......@@ -1582,7 +1604,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
......@@ -1606,8 +1628,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1619,8 +1641,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
......@@ -1682,7 +1704,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_head_dim=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
......@@ -1693,10 +1715,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.attention_head_dim = attention_head_dim
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.num_heads = in_channels // self.attn_num_head_channels
self.num_heads = in_channels // self.attention_head_dim
# there is always at least one resnet
resnets = [
......@@ -1726,7 +1748,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
query_dim=in_channels,
cross_attention_dim=in_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
dim_head=self.attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
......
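In `UNetMidBlockFlatSimpleCrossAttn` above, the renamed `attention_head_dim` argument now feeds both the head count and the per-head width: `num_heads = in_channels // attention_head_dim` and `dim_head = attention_head_dim`. A quick worked example with illustrative numbers (not taken from the diff):

```python
in_channels = 1280          # illustrative channel count
attention_head_dim = 64     # illustrative per-head width

num_heads = in_channels // attention_head_dim   # 20 heads...
dim_head = attention_head_dim                   # ...each 64 channels wide
assert num_heads * dim_head == in_channels
```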
......@@ -59,7 +59,7 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"block_out_channels": (32, 64),
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
"up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
"attention_head_dim": None,
"attention_head_dim": 3,
"out_channels": 3,
"in_channels": 3,
"layers_per_block": 2,
......