Unverified commit 5e036921 authored by Pedro Cuenca, committed by GitHub

Make cross-attention check more robust (#1560)

* Make cross-attention check more robust.

* Fix copies.
parent bea7eb43
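
In brief: cross-attention blocks now set an explicit has_cross_attention flag in their constructors, and the UNet forward passes dispatch on that flag instead of probing for a non-None attentions attribute. The sketch below is a minimal, hypothetical illustration (the two block classes are stand-ins, not diffusers types) of why the old probe is fragile: presumably a block can own attentions for plain self-attention while its forward() does not accept encoder_hidden_states.

# Hypothetical sketch, not the diffusers source: two toy blocks showing why
# an explicit flag is more robust than checking the "attentions" attribute.
import torch
from torch import nn


class SelfAttnOnlyBlock(nn.Module):
    # Owns "attentions" but its forward() takes no encoder_hidden_states,
    # so an attribute-based check would misroute it.
    def __init__(self):
        super().__init__()
        self.attentions = nn.ModuleList([nn.Identity()])

    def forward(self, hidden_states):
        return self.attentions[0](hidden_states)


class CrossAttnBlock(nn.Module):
    # Advertises cross-attention support explicitly, mirroring this commit.
    def __init__(self):
        super().__init__()
        self.has_cross_attention = True
        self.attentions = nn.ModuleList([nn.Identity()])

    def forward(self, hidden_states, encoder_hidden_states=None):
        return self.attentions[0](hidden_states)


def run_blocks(blocks, sample, encoder_hidden_states):
    for block in blocks:
        # New-style check: an explicit opt-in flag. The old form,
        # hasattr(block, "attentions") and block.attentions is not None,
        # would also match SelfAttnOnlyBlock and then crash below with a
        # TypeError on the unexpected encoder_hidden_states keyword.
        if getattr(block, "has_cross_attention", False):
            sample = block(sample, encoder_hidden_states=encoder_hidden_states)
        else:
            sample = block(sample)
    return sample


out = run_blocks([SelfAttnOnlyBlock(), CrossAttnBlock()], torch.randn(1, 4), torch.randn(1, 77, 768))

Here getattr(block, "has_cross_attention", False) is shorthand for the hasattr-and-truthy form used in the diff below; both treat blocks without the flag as unconditioned blocks.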
@@ -343,6 +343,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ):
         super().__init__()

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -526,6 +527,7 @@ class CrossAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -1110,6 +1112,7 @@ class CrossAttnUpBlock2D(nn.Module):
         resnets = []
         attentions = []

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
...
@@ -377,7 +377,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -403,7 +403,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]

-            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
...
@@ -455,7 +455,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]

-            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -726,6 +726,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         resnets = []
         attentions = []

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -924,6 +925,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         resnets = []
         attentions = []

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
@@ -1043,6 +1045,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ):
         super().__init__()

+        self.has_cross_attention = True
         self.attention_type = attention_type
         self.attn_num_head_channels = attn_num_head_channels
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
...