Unverified Commit 5e036921 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Make cross-attention check more robust (#1560)

* Make cross-attention check more robust.

* Fix copies.
parent bea7eb43
......@@ -343,6 +343,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
):
super().__init__()
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
......@@ -526,6 +527,7 @@ class CrossAttnDownBlock2D(nn.Module):
resnets = []
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
......@@ -1110,6 +1112,7 @@ class CrossAttnUpBlock2D(nn.Module):
resnets = []
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
......
......@@ -377,7 +377,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
......@@ -403,7 +403,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
......
......@@ -455,7 +455,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
......@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
......@@ -726,6 +726,7 @@ class CrossAttnDownBlockFlat(nn.Module):
resnets = []
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
......@@ -924,6 +925,7 @@ class CrossAttnUpBlockFlat(nn.Module):
resnets = []
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
......@@ -1043,6 +1045,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
):
super().__init__()
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment