Unverified Commit cecdd8bd authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

Adapt UNet2D for super-resolution (#1385)

* allow disabling self attention

* add class_embedding

* fix copies

* fix condition

* fix copies

* do_self_attention -> only_cross_attention

* fix copies

* num_classes -> num_class_embeds

* fix default value
parent 30f6f441
...@@ -100,6 +100,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -100,6 +100,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
activation_fn: str = "geglu", activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None, num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False, use_linear_projection: bool = False,
only_cross_attention: bool = False,
): ):
super().__init__() super().__init__()
self.use_linear_projection = use_linear_projection self.use_linear_projection = use_linear_projection
...@@ -157,6 +158,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -157,6 +158,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
activation_fn=activation_fn, activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm, num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias, attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
) )
for d in range(num_layers) for d in range(num_layers)
] ]
...@@ -387,14 +389,17 @@ class BasicTransformerBlock(nn.Module): ...@@ -387,14 +389,17 @@ class BasicTransformerBlock(nn.Module):
activation_fn: str = "geglu", activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None, num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False, attention_bias: bool = False,
only_cross_attention: bool = False,
): ):
super().__init__() super().__init__()
self.only_cross_attention = only_cross_attention
self.attn1 = CrossAttention( self.attn1 = CrossAttention(
query_dim=dim, query_dim=dim,
heads=num_attention_heads, heads=num_attention_heads,
dim_head=attention_head_dim, dim_head=attention_head_dim,
dropout=dropout, dropout=dropout,
bias=attention_bias, bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
) # is a self-attention ) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention( self.attn2 = CrossAttention(
...@@ -461,6 +466,10 @@ class BasicTransformerBlock(nn.Module): ...@@ -461,6 +466,10 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = ( norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
) )
if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states hidden_states = self.attn1(norm_hidden_states) + hidden_states
# 2. Cross-Attention # 2. Cross-Attention
......
...@@ -34,6 +34,7 @@ def get_down_block( ...@@ -34,6 +34,7 @@ def get_down_block(
downsample_padding=None, downsample_padding=None,
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False,
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D": if down_block_type == "DownBlock2D":
...@@ -78,6 +79,7 @@ def get_down_block( ...@@ -78,6 +79,7 @@ def get_down_block(
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
elif down_block_type == "SkipDownBlock2D": elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D( return SkipDownBlock2D(
...@@ -143,6 +145,7 @@ def get_up_block( ...@@ -143,6 +145,7 @@ def get_up_block(
cross_attention_dim=None, cross_attention_dim=None,
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D": if up_block_type == "UpBlock2D":
...@@ -174,6 +177,7 @@ def get_up_block( ...@@ -174,6 +177,7 @@ def get_up_block(
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
elif up_block_type == "AttnUpBlock2D": elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D( return AttnUpBlock2D(
...@@ -530,6 +534,7 @@ class CrossAttnDownBlock2D(nn.Module): ...@@ -530,6 +534,7 @@ class CrossAttnDownBlock2D(nn.Module):
add_downsample=True, add_downsample=True,
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -564,6 +569,7 @@ class CrossAttnDownBlock2D(nn.Module): ...@@ -564,6 +569,7 @@ class CrossAttnDownBlock2D(nn.Module):
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups, norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
) )
else: else:
...@@ -1129,6 +1135,7 @@ class CrossAttnUpBlock2D(nn.Module): ...@@ -1129,6 +1135,7 @@ class CrossAttnUpBlock2D(nn.Module):
add_upsample=True, add_upsample=True,
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1165,6 +1172,7 @@ class CrossAttnUpBlock2D(nn.Module): ...@@ -1165,6 +1172,7 @@ class CrossAttnUpBlock2D(nn.Module):
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups, norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
) )
else: else:
......
...@@ -98,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -98,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
"DownBlock2D", "DownBlock2D",
), ),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2, layers_per_block: int = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
...@@ -109,6 +110,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -109,6 +110,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
attention_head_dim: Union[int, Tuple[int]] = 8, attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False, dual_cross_attention: bool = False,
use_linear_projection: bool = False, use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
): ):
super().__init__() super().__init__()
...@@ -124,10 +126,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -124,10 +126,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
self.down_blocks = nn.ModuleList([]) self.down_blocks = nn.ModuleList([])
self.mid_block = None self.mid_block = None
self.up_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int): if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types) attention_head_dim = (attention_head_dim,) * len(down_block_types)
...@@ -153,6 +162,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -153,6 +162,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
...@@ -177,6 +187,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -177,6 +187,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# up # up
reversed_block_out_channels = list(reversed(block_out_channels)) reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim)) reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0] output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types): for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1 is_final_block = i == len(block_out_channels) - 1
...@@ -207,6 +218,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -207,6 +218,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
attn_num_head_channels=reversed_attention_head_dim[i], attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
...@@ -258,6 +270,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -258,6 +270,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor, sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]: ) -> Union[UNet2DConditionOutput, Tuple]:
r""" r"""
...@@ -310,6 +323,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -310,6 +323,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype) t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb) emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process # 2. pre-process
sample = self.conv_in(sample) sample = self.conv_in(sample)
......
...@@ -166,6 +166,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -166,6 +166,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
), ),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2, layers_per_block: int = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
...@@ -177,6 +178,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -177,6 +178,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attention_head_dim: Union[int, Tuple[int]] = 8, attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False, dual_cross_attention: bool = False,
use_linear_projection: bool = False, use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
): ):
super().__init__() super().__init__()
...@@ -192,10 +194,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -192,10 +194,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
self.down_blocks = nn.ModuleList([]) self.down_blocks = nn.ModuleList([])
self.mid_block = None self.mid_block = None
self.up_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int): if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types) attention_head_dim = (attention_head_dim,) * len(down_block_types)
...@@ -221,6 +230,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -221,6 +230,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
...@@ -245,6 +255,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -245,6 +255,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# up # up
reversed_block_out_channels = list(reversed(block_out_channels)) reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim)) reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0] output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types): for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1 is_final_block = i == len(block_out_channels) - 1
...@@ -275,6 +286,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -275,6 +286,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attn_num_head_channels=reversed_attention_head_dim[i], attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention, dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
...@@ -326,6 +338,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -326,6 +338,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor, sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]: ) -> Union[UNet2DConditionOutput, Tuple]:
r""" r"""
...@@ -378,6 +391,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -378,6 +391,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype) t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb) emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process # 2. pre-process
sample = self.conv_in(sample) sample = self.conv_in(sample)
...@@ -648,6 +667,7 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -648,6 +667,7 @@ class CrossAttnDownBlockFlat(nn.Module):
add_downsample=True, add_downsample=True,
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -682,6 +702,7 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -682,6 +702,7 @@ class CrossAttnDownBlockFlat(nn.Module):
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups, norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
) )
else: else:
...@@ -861,6 +882,7 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -861,6 +882,7 @@ class CrossAttnUpBlockFlat(nn.Module):
add_upsample=True, add_upsample=True,
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
only_cross_attention=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -897,6 +919,7 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -897,6 +919,7 @@ class CrossAttnUpBlockFlat(nn.Module):
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups, norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection, use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
) )
) )
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment