Unverified Commit c352faea authored by Patrick von Platen, committed by GitHub

Add MidBlock to Grad-TTS (#74)

Finish
parent 10798663
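
In effect, this commit folds the Grad-TTS UNet's mid section (resnet -> linear attention -> resnet) into a single UNetMidBlock2D module while keeping the previously built sub-modules. The net change to the forward pass, excerpted from the diff below:

    # before: three separate calls in UNetGradTTSModel.forward
    x = self.mid_block1(x, t, mask_mid)
    x = self.mid_attn(x)
    x = self.mid_block2(x, t, mask_mid)

    # after: a single call; UNetMidBlock2D forwards the mask to both of its resnets
    x = self.mid(x, t, mask=mask_mid)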
@@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D


 class Mish(torch.nn.Module):
@@ -111,6 +112,17 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )

         mid_dim = dims[-1]
+
+        self.mid = UNetMidBlock2D(
+            in_channels=mid_dim,
+            temb_channels=dim,
+            resnet_groups=8,
+            resnet_pre_norm=False,
+            resnet_eps=1e-5,
+            resnet_act_fn="mish",
+            attention_layer_type="linear",
+        )
+
         self.mid_block1 = ResnetBlock2D(
             in_channels=mid_dim,
             out_channels=mid_dim,
@@ -132,8 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             non_linearity="mish",
             overwrite_for_grad_tts=True,
         )
-
-        # self.mid = UNetMidBlock2D
+        self.mid.resnet_1 = self.mid_block1
+        self.mid.attn = self.mid_attn
+        self.mid.resnet_2 = self.mid_block2

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
@@ -198,9 +211,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         masks = masks[:-1]
         mask_mid = masks[-1]

-        x = self.mid_block1(x, t, mask_mid)
-        x = self.mid_attn(x)
-        x = self.mid_block2(x, t, mask_mid)
+        x = self.mid(x, t, mask=mask_mid)

         for resnet1, resnet2, attn, upsample in self.ups:
             mask_up = masks.pop()
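
The wiring above follows a build-then-repoint pattern: the new UNetMidBlock2D is constructed first, and its resnet_1 / attn / resnet_2 attributes are then overwritten with the already-existing Grad-TTS modules, so only the call site changes while the parameters stay the ones the checkpoint expects. Below is a self-contained toy illustration of that pattern; ToyResnet, ToyAttn and ToyMidBlock are stand-ins written for this note, not the diffusers classes.

import torch
from torch import nn


class ToyResnet(nn.Module):
    # stand-in for ResnetBlock2D: takes a timestep embedding and a mask, like the Grad-TTS resnets
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x, temb=None, mask=1.0):
        return self.conv(x * mask) * mask


class ToyAttn(nn.Module):
    # stand-in for the linear-attention module at the mid point
    def __init__(self, channels):
        super().__init__()
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        return x + self.proj(x)


class ToyMidBlock(nn.Module):
    # stand-in for UNetMidBlock2D: resnet -> attention -> resnet, mask forwarded to both resnets
    def __init__(self, channels):
        super().__init__()
        self.resnet_1 = ToyResnet(channels)
        self.attn = ToyAttn(channels)
        self.resnet_2 = ToyResnet(channels)

    def forward(self, x, temb=None, mask=1.0):
        x = self.resnet_1(x, temb, mask=mask)
        x = self.attn(x)
        return self.resnet_2(x, temb, mask=mask)


mid_dim = 64
mid = ToyMidBlock(mid_dim)

# mirror `self.mid.resnet_1 = self.mid_block1` etc.: keep the previously built modules
# and re-point the new block at them
mid_block1, mid_attn, mid_block2 = ToyResnet(mid_dim), ToyAttn(mid_dim), ToyResnet(mid_dim)
mid.resnet_1, mid.attn, mid.resnet_2 = mid_block1, mid_attn, mid_block2

x = torch.randn(2, mid_dim, 20, 43)   # illustrative spectrogram-shaped feature map
mask_mid = torch.ones(2, 1, 20, 43)   # mask broadcast over the channel dimension
out = mid(x, None, mask=mask_mid)     # replaces the former mid_block1 -> mid_attn -> mid_block2 sequence
print(out.shape)                      # torch.Size([2, 64, 20, 43])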
...
@@ -14,7 +14,7 @@
 # limitations under the License.

 from torch import nn

-from .attention import AttentionBlock, SpatialTransformer
+from .attention import AttentionBlock, LinearAttention, SpatialTransformer
 from .resnet import ResnetBlock2D

@@ -23,11 +23,12 @@ class UNetMidBlock2D(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
-        dropout: float,
+        dropout: float = 0.0,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
         attention_layer_type: str = "self",
         attn_num_heads=1,
         attn_num_head_channels=None,
@@ -49,6 +50,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )

         if attention_layer_type == "self":
@@ -61,15 +63,14 @@ class UNetMidBlock2D(nn.Module):
                 rescale_output_factor=output_scale_factor,
             )
         elif attention_layer_type == "spatial":
-            self.attn = (
-                SpatialTransformer(
-                    in_channels,
-                    attn_num_heads,
-                    attn_num_head_channels,
-                    depth=attn_depth,
-                    context_dim=attn_encoder_channels,
-                ),
-            )
+            self.attn = SpatialTransformer(
+                attn_num_heads,
+                attn_num_head_channels,
+                depth=attn_depth,
+                context_dim=attn_encoder_channels,
+            )
+        elif attention_layer_type == "linear":
+            self.attn = LinearAttention(in_channels)

         self.resnet_2 = ResnetBlock2D(
             in_channels=in_channels,
@@ -80,6 +81,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )

         # TODO(Patrick) - delete all of the following code
@@ -110,19 +112,20 @@ class UNetMidBlock2D(nn.Module):
             eps=resnet_eps,
         )

-    def forward(self, hidden_states, temb=None, encoder_states=None):
+    def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
         if not self.is_overwritten and self.overwrite_unet:
             self.resnet_1 = self.block_1
             self.attn = self.attn_1
             self.resnet_2 = self.block_2
             self.is_overwritten = True

-        hidden_states = self.resnet_1(hidden_states, temb)
+        hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
         if encoder_states is None:
             hidden_states = self.attn(hidden_states)
         else:
             hidden_states = self.attn(hidden_states, encoder_states)
-        hidden_states = self.resnet_2(hidden_states, temb)
+        hidden_states = self.resnet_2(hidden_states, temb, mask=mask)

         return hidden_states
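
Read together, the two files give UNetMidBlock2D a mask-aware forward and a "linear" attention option, which the Grad-TTS model uses as its new mid block. Below is a hedged usage sketch of the class as configured in this commit; it is not code from the commit itself: the import path is assumed from the relative import added above, and the channel sizes and tensor shapes are made-up placeholders.

import torch
from diffusers.models.unet_new import UNetMidBlock2D  # path assumed from `from .unet_new import UNetMidBlock2D`

mid_dim, dim = 256, 64                  # placeholders for dims[-1] and the base/temb dimension

mid = UNetMidBlock2D(
    in_channels=mid_dim,
    temb_channels=dim,
    resnet_groups=8,
    resnet_pre_norm=False,              # post-norm resnets, matching the Grad-TTS blocks
    resnet_eps=1e-5,
    resnet_act_fn="mish",
    attention_layer_type="linear",      # new branch: self.attn = LinearAttention(in_channels)
)

x = torch.randn(1, mid_dim, 20, 43)     # placeholder feature map
t = torch.randn(1, dim)                 # placeholder timestep embedding
mask_mid = torch.ones(1, 1, 20, 43)     # mask broadcast over channels
x = mid(x, t, mask=mask_mid)            # mask is forwarded to resnet_1 and resnet_2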