Unverified commit ea8d58ea, authored by Patrick von Platen, committed by GitHub

[MidBlock] Fix mid block (#78)

* upload files

* finish
parent c352faea
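
The core change in this commit: UNetMidBlock2D stops hard-coding a resnet_1 / attn / resnet_2 attribute triple and instead builds `resnets` and `attentions` ModuleLists (one leading resnet, then `num_blocks` attention/resnet pairs) that `forward` iterates over. Below is a minimal sketch of that layout; `ToyMidBlock` and its plain convolutions are illustrative stand-ins for the real ResnetBlock2D and AttentionBlock, not part of the diff.

import torch
from torch import nn


class ToyMidBlock(nn.Module):
    """Stand-in for UNetMidBlock2D: one leading resnet, then
    num_blocks (attention, resnet) pairs, as in this commit."""

    def __init__(self, channels: int, num_blocks: int = 1):
        super().__init__()
        # stand-ins: a 3x3 conv per "resnet", a 1x1 conv per "attention"
        resnets = [nn.Conv2d(channels, channels, 3, padding=1)]
        attentions = []
        for _ in range(num_blocks):
            attentions.append(nn.Conv2d(channels, channels, 1))
            resnets.append(nn.Conv2d(channels, channels, 3, padding=1))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states):
        # same control flow as the new UNetMidBlock2D.forward
        hidden_states = self.resnets[0](hidden_states)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = resnet(attn(hidden_states))
        return hidden_states


x = torch.randn(1, 32, 8, 8)
assert ToyMidBlock(32, num_blocks=2)(x).shape == x.shape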
@@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module):
         self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
         self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

-    def forward(self, x):
+    def forward(self, x, encoder_states=None):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
         q, k, v = (
...
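
LinearAttention gains an `encoder_states` argument that its body never reads. The point, presumably, is a uniform call signature across attention variants, so the mid block can call `attn(hidden_states, encoder_states)` without branching on the attention type. A small sketch of that pattern; both classes here are hypothetical stand-ins:

class SelfOnlyAttention:
    # accepts encoder_states for signature compatibility, then ignores it
    def __call__(self, x, encoder_states=None):
        return x


class CrossAttention:
    # actually conditions on encoder_states when given
    def __call__(self, x, encoder_states=None):
        return x if encoder_states is None else x + encoder_states


# one call site works for every variant, conditioned or not
for attn in (SelfOnlyAttention(), CrossAttention()):
    attn(1.0, encoder_states=None)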
@@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin):
             self.down.append(down)

         # middle
-        self.mid = UNetMidBlock2D(
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock2D(
+            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+        )
+        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
+        self.mid.block_2 = ResnetBlock2D(
+            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+        )
+        self.mid_new = UNetMidBlock2D(
             in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, overwrite_qkv=True, overwrite_unet=True
         )
+        self.mid_new.resnets[0] = self.mid.block_1
+        self.mid_new.attentions[0] = self.mid.attn_1
+        self.mid_new.resnets[1] = self.mid.block_2

         # upsampling
         self.up = nn.ModuleList()
@@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin):
                 hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
-        h = self.mid(hs[-1], temb)
-        # h = self.mid.block_1(h, temb)
-        # h = self.mid.attn_1(h)
-        # h = self.mid.block_2(h, temb)
+        h = self.mid_new(hs[-1], temb)

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
...
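
Note the aliasing in the first hunk above: `self.mid` keeps the pretrained attribute names (`block_1`, `attn_1`, `block_2`) so existing checkpoint keys such as `mid.block_1.*` still resolve, while `self.mid_new` holds the very same submodules under the new `resnets`/`attentions` layout and is what `forward` now calls. A minimal sketch of why assignment shares rather than copies parameters; the Linear layers are stand-ins for the real blocks:

import torch
from torch import nn

mid = nn.Module()
mid.block_1 = nn.Linear(4, 4)  # registered under the legacy name

mid_new = nn.Module()
mid_new.resnets = nn.ModuleList([nn.Linear(4, 4)])
mid_new.resnets[0] = mid.block_1  # alias the same module, no copy

# both names point at the identical parameter tensor
assert mid.block_1.weight is mid_new.resnets[0].weight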
@@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 overwrite_for_glide=True,
             ),
         )
-        self.mid.resnet_1 = self.middle_block[0]
-        self.mid.attn = self.middle_block[1]
-        self.mid.resnet_2 = self.middle_block[2]
+        self.mid.resnets[0] = self.middle_block[0]
+        self.mid.attentions[0] = self.middle_block[1]
+        self.mid.resnets[1] = self.middle_block[2]
         self._feature_size += ch
@@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
         for module in self.input_blocks:
             h = module(h, emb)
             hs.append(h)
-        h = self.middle_block(h, emb)
+        h = self.mid(h, emb)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
...
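
The `self.mid.resnets[0] = ...` pattern used here (and in the grad-tts, LDM, and NCSNpp hunks below) relies on `nn.ModuleList` supporting item assignment: the assigned module is re-registered under the same index, so its parameters appear in the parent's `state_dict` as usual. A quick sketch:

from torch import nn

resnets = nn.ModuleList([nn.Linear(2, 2), nn.Linear(2, 2)])
replacement = nn.Linear(2, 2)
resnets[0] = replacement  # re-registers under key "0"

assert resnets[0] is replacement
assert "0.weight" in resnets.state_dict()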
@@ -19,8 +19,8 @@ class Rezero(torch.nn.Module):
         self.fn = fn
         self.g = torch.nn.Parameter(torch.zeros(1))

-    def forward(self, x):
-        return self.fn(x) * self.g
+    def forward(self, x, encoder_out=None):
+        return self.fn(x, encoder_out) * self.g


 class Block(torch.nn.Module):
@@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             non_linearity="mish",
             overwrite_for_grad_tts=True,
         )
-        self.mid.resnet_1 = self.mid_block1
-        self.mid.attn = self.mid_attn
-        self.mid.resnet_2 = self.mid_block2
+        self.mid.resnets[0] = self.mid_block1
+        self.mid.attentions[0] = self.mid_attn
+        self.mid.resnets[1] = self.mid_block2

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
...
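
The Rezero change mirrors the LinearAttention one: the wrapper now passes the optional encoder output through to the wrapped function instead of swallowing it. A sketch of the pass-through; the identity `fn` is a stand-in for the wrapped attention:

import torch


class Rezero(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x, encoder_out=None):
        # forward encoder_out to the wrapped fn; g starts at zero,
        # so the module initially contributes nothing
        return self.fn(x, encoder_out) * self.g


identity = lambda x, encoder_states=None: x
out = Rezero(identity)(torch.ones(2), encoder_out=None)
assert torch.equal(out, torch.zeros(2))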
@@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 overwrite_for_ldm=True,
             ),
         )
-        self.mid.resnet_1 = self.middle_block[0]
-        self.mid.attn = self.middle_block[1]
-        self.mid.resnet_2 = self.middle_block[2]
+        self.mid.resnets[0] = self.middle_block[0]
+        self.mid.attentions[0] = self.middle_block[1]
+        self.mid.resnets[1] = self.middle_block[2]
         self._feature_size += ch
...
@@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module):
         in_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
+        num_blocks: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -41,91 +42,95 @@ class UNetMidBlock2D(nn.Module):
     ):
         super().__init__()

-        self.resnet_1 = ResnetBlock2D(
-            in_channels=in_channels,
-            out_channels=in_channels,
-            temb_channels=temb_channels,
-            groups=resnet_groups,
-            dropout=dropout,
-            time_embedding_norm=resnet_time_scale_shift,
-            non_linearity=resnet_act_fn,
-            output_scale_factor=output_scale_factor,
-            pre_norm=resnet_pre_norm,
-        )
-
-        if attention_layer_type == "self":
-            self.attn = AttentionBlock(
-                in_channels,
-                num_heads=attn_num_heads,
-                num_head_channels=attn_num_head_channels,
-                encoder_channels=attn_encoder_channels,
-                overwrite_qkv=overwrite_qkv,
-                rescale_output_factor=output_scale_factor,
-            )
-        elif attention_layer_type == "spatial":
-            self.attn = SpatialTransformer(
-                attn_num_heads,
-                attn_num_head_channels,
-                depth=attn_depth,
-                context_dim=attn_encoder_channels,
-            )
-        elif attention_layer_type == "linear":
-            self.attn = LinearAttention(in_channels)
-
-        self.resnet_2 = ResnetBlock2D(
-            in_channels=in_channels,
-            out_channels=in_channels,
-            temb_channels=temb_channels,
-            groups=resnet_groups,
-            dropout=dropout,
-            time_embedding_norm=resnet_time_scale_shift,
-            non_linearity=resnet_act_fn,
-            output_scale_factor=output_scale_factor,
-            pre_norm=resnet_pre_norm,
-        )
-
-        # TODO(Patrick) - delete all of the following code
-        self.is_overwritten = False
-        self.overwrite_unet = overwrite_unet
-        if self.overwrite_unet:
-            block_in = in_channels
-            self.temb_ch = temb_channels
-            self.block_1 = ResnetBlock2D(
-                in_channels=block_in,
-                out_channels=block_in,
-                temb_channels=self.temb_ch,
-                dropout=dropout,
-                eps=resnet_eps,
-            )
-            self.attn_1 = AttentionBlock(
-                block_in,
-                num_heads=attn_num_heads,
-                num_head_channels=attn_num_head_channels,
-                encoder_channels=attn_encoder_channels,
-                overwrite_qkv=True,
-            )
-            self.block_2 = ResnetBlock2D(
-                in_channels=block_in,
-                out_channels=block_in,
-                temb_channels=self.temb_ch,
-                dropout=dropout,
-                eps=resnet_eps,
-            )
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        for _ in range(num_blocks):
+            if attention_layer_type == "self":
+                attentions.append(
+                    AttentionBlock(
+                        in_channels,
+                        num_heads=attn_num_heads,
+                        num_head_channels=attn_num_head_channels,
+                        encoder_channels=attn_encoder_channels,
+                        overwrite_qkv=overwrite_qkv,
+                        rescale_output_factor=output_scale_factor,
+                    )
+                )
+            elif attention_layer_type == "spatial":
+                attentions.append(
+                    SpatialTransformer(
+                        in_channels,
+                        attn_num_heads,
+                        attn_num_head_channels,
+                        depth=attn_depth,
+                        context_dim=attn_encoder_channels,
+                    )
+                )
+            elif attention_layer_type == "linear":
+                attentions.append(LinearAttention(in_channels))
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)

     def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
-        if not self.is_overwritten and self.overwrite_unet:
-            self.resnet_1 = self.block_1
-            self.attn = self.attn_1
-            self.resnet_2 = self.block_2
-            self.is_overwritten = True
-
-        hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
-
-        if encoder_states is None:
-            hidden_states = self.attn(hidden_states)
-        else:
-            hidden_states = self.attn(hidden_states, encoder_states)
-
-        hidden_states = self.resnet_2(hidden_states, temb, mask=mask)
+        hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
+
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            hidden_states = attn(hidden_states, encoder_states)
+            hidden_states = resnet(hidden_states, temb, mask=mask)

         return hidden_states
+
+
+# class UNetResAttnDownBlock(nn.Module):
+#     def __init__(
+#         self,
+#         in_channels: int,
+#         out_channels: int,
+#         temb_channels: int,
+#         dropout: float = 0.0,
+#         resnet_eps: float = 1e-6,
+#         resnet_time_scale_shift: str = "default",
+#         resnet_act_fn: str = "swish",
+#         resnet_groups: int = 32,
+#         resnet_pre_norm: bool = True,
+#         attention_layer_type: str = "self",
+#         attn_num_heads=1,
+#         attn_num_head_channels=None,
+#         attn_encoder_channels=None,
+#         attn_dim_head=None,
+#         attn_depth=None,
+#         output_scale_factor=1.0,
+#         overwrite_qkv=False,
+#         overwrite_unet=False,
+#     ):
+#
+#         self.resents =
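
With the default `num_blocks=1`, the new forward loop unrolls to exactly the old fixed pipeline resnet_1 -> attn -> resnet_2 (modulo the removed `encoder_states is None` branch, since the attention variants now all accept the argument). A plain-callable sketch of that equivalence; the lambdas are illustrative stand-ins, not the real blocks:

resnets = [lambda h: h + 1, lambda h: h * 2]   # stand-ins for the two resnets
attentions = [lambda h: h * 10]                # stand-in for the attention


def forward_new(h):
    # new loop over (attention, resnet) pairs
    h = resnets[0](h)
    for attn, resnet in zip(attentions, resnets[1:]):
        h = resnet(attn(h))
    return h


def forward_old(h):
    # old fixed sequence: resnet_1 -> attn -> resnet_2
    return resnets[1](attentions[0](resnets[0](h)))


assert forward_new(3) == forward_old(3) == 80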
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 overwrite_for_score_vde=True,
             )
         )
-        self.mid.resnet_1 = modules[len(modules) - 3]
-        self.mid.attn = modules[len(modules) - 2]
-        self.mid.resnet_2 = modules[len(modules) - 1]
+        self.mid.resnets[0] = modules[len(modules) - 3]
+        self.mid.attentions[0] = modules[len(modules) - 2]
+        self.mid.resnets[1] = modules[len(modules) - 1]
         pyramid_ch = 0

         # Upsampling block
...