"docs/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ab03dc4370acaaef05810465a077691472624b2b"
Unverified commit 10798663, authored by Anton Lozhkov, committed by GitHub

Fix attention for Glide (#75)

parent d9316bf8
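
Summary of the change, as read from the diff below: `AttentionBlock.__init__` now stores the `overwrite_linear` flag, the `overwrite_qkv` / `overwrite_linear` branches are collapsed into a single `if`/`elif`/`else` chain whose new `else` arm creates the zero-initialized `proj_out` used by the plain (Glide-style) configuration, `set_weights` gains an `else` branch that copies `proj_out` into `proj`, and `forward` now runs `set_weights` on the first pass unconditionally rather than only when one of the `overwrite_*` flags is set.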
@@ -73,6 +73,8 @@ class AttentionBlock(nn.Module):
         self.proj = zero_module(nn.Conv1d(channels, channels, 1))

         self.overwrite_qkv = overwrite_qkv
+        self.overwrite_linear = overwrite_linear
+
         if overwrite_qkv:
             in_channels = channels
             self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
@@ -80,9 +82,7 @@ class AttentionBlock(nn.Module):
             self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-        self.overwrite_linear = overwrite_linear
-        if self.overwrite_linear:
+        elif self.overwrite_linear:
             num_groups = min(channels // 4, 32)
             self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
             self.NIN_0 = NIN(channels, channels)
@@ -91,6 +91,8 @@ class AttentionBlock(nn.Module):
             self.NIN_3 = NIN(channels, channels)

             self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
+        else:
+            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

         self.is_overwritten = False
@@ -120,9 +122,12 @@ class AttentionBlock(nn.Module):
             self.norm.weight.data = self.GroupNorm_0.weight.data
             self.norm.bias.data = self.GroupNorm_0.bias.data
+        else:
+            self.proj.weight.data = module.proj_out.weight.data
+            self.proj.bias.data = module.proj_out.bias.data

     def forward(self, x, encoder_out=None):
-        if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
+        if not self.is_overwritten:
             self.set_weights(self)
             self.is_overwritten = True
...
@@ -133,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             overwrite_for_grad_tts=True,
         )

         # self.mid = UNetMidBlock2D
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
...
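
The pattern the commit repairs is a lazy weight overwrite: parameters loaded from a legacy checkpoint under one attribute name are copied into the module's canonical parameters on the first forward pass. Below is a minimal, self-contained sketch of that pattern, assuming hypothetical names (`LazyOverwriteBlock`, the simplified `set_weights`); it is not the diffusers `AttentionBlock` itself, only an illustration of why gating the copy on the `overwrite_*` flags left the plain configuration with zero-initialized weights.

```python
# Minimal sketch of the lazy weight-overwrite pattern fixed in this commit.
# Class and attribute names are hypothetical illustrations.
import torch
import torch.nn as nn


class LazyOverwriteBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        # Canonical projection; starts zero-initialized, like zero_module(...).
        self.proj = nn.Conv1d(channels, channels, 1)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)
        # Legacy checkpoints store the same weights under a different name.
        self.proj_out = nn.Conv1d(channels, channels, 1)
        self.is_overwritten = False

    def set_weights(self, module):
        # Copy the legacy parameters into the canonical ones. Before the fix,
        # this step was skipped for the "plain" configuration, so `proj` kept
        # its zero weights and the block's output was silently all zeros.
        self.proj.weight.data = module.proj_out.weight.data
        self.proj.bias.data = module.proj_out.bias.data

    def forward(self, x):
        # The fix: run the overwrite on the first forward unconditionally,
        # not only when an overwrite_* flag happens to be set.
        if not self.is_overwritten:
            self.set_weights(self)
            self.is_overwritten = True
        return self.proj(x)


if __name__ == "__main__":
    block = LazyOverwriteBlock(channels=8)
    out = block(torch.randn(1, 8, 16))
    print(out.shape)  # torch.Size([1, 8, 16])
```

Deferring the copy to the first forward (rather than doing it in `__init__`) keeps the overwrite independent of the order in which `load_state_dict` populates the legacy attributes, which is presumably why the block uses an `is_overwritten` latch instead of copying eagerly.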