Commit 4f1df69d authored by sayakpaul

Revert "add attention_head_dim"

This reverts commit 15f6b224.
parent 15f6b224
@@ -158,7 +158,6 @@ class BasicTransformerBlock(nn.Module):
super().__init__()
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backwards-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
@@ -120,7 +120,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers are skipped in post-processing.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*): The number of attention heads.
"""
@@ -148,16 +147,10 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
layers_per_block: int = 2,
norm_num_groups: Optional[int] = 32,
cross_attention_dim: int = 1024,
- attention_head_dim: Union[int, Tuple[int]] = None,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 64,
):
super().__init__()
- # We didn't define `attention_head_dim` when we first integrated this UNet. As a result,
- # we had to use `num_attention_heads` to pass values for arguments that actually denote
- # attention head dimension. This is why we correct it here.
- attention_head_dim = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
@@ -179,7 +172,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
self.transformer_in = TransformerTemporalModel(
num_attention_heads=8,
- attention_head_dim=attention_head_dim,
+ attention_head_dim=num_attention_heads,
in_channels=block_out_channels[0],
num_layers=1,
norm_num_groups=norm_num_groups,
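For reference, the last hunk restores the pre-existing naming quirk that the removed comment described: the value supplied as `num_attention_heads` is forwarded to the temporal transformer's `attention_head_dim` argument. Below is a standalone sketch of that restored call, with `block_out_channels` assumed (320 as the first block width is an illustration, not taken from the diff) and the other values matching the defaults shown above.

```python
from diffusers.models import TransformerTemporalModel

# Values mirroring the diff; block_out_channels is an assumed example, the
# rest are the defaults shown in the I2VGenXLUNet signature above.
num_attention_heads = 64
norm_num_groups = 32
block_out_channels = (320, 640, 1280, 1280)  # assumption, not from the diff

# After the revert, the UNet's `transformer_in` once again receives the
# `num_attention_heads` value in its `attention_head_dim` slot, i.e. 64 is
# treated as the per-head dimension here, not the head count.
transformer_in = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=num_attention_heads,
    in_channels=block_out_channels[0],
    num_layers=1,
    norm_num_groups=norm_num_groups,
)
```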