Unverified Commit 10798663 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Fix attention for Glide (#75)

parent d9316bf8
......@@ -73,6 +73,8 @@ class AttentionBlock(nn.Module):
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
......@@ -80,9 +82,7 @@ class AttentionBlock(nn.Module):
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.overwrite_linear = overwrite_linear
if self.overwrite_linear:
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
......@@ -91,6 +91,8 @@ class AttentionBlock(nn.Module):
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.is_overwritten = False
......@@ -120,9 +122,12 @@ class AttentionBlock(nn.Module):
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = module.proj_out.weight.data
self.proj.bias.data = module.proj_out.bias.data
def forward(self, x, encoder_out=None):
if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
if not self.is_overwritten:
self.set_weights(self)
self.is_overwritten = True
......
......@@ -133,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
overwrite_for_grad_tts=True,
)
# self.mid = UNetMidBlock2D
# self.mid = UNetMidBlock2D
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment