Commit 31d1f3c8 authored by Patrick von Platen

final fix

parent 635da723
@@ -91,11 +91,15 @@ class AttentionBlock(nn.Module):
         self.NIN_2 = NIN(channels, channels)
         self.NIN_3 = NIN(channels, channels)
+        self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
         self.is_overwritten = False

     def set_weights(self, module):
         if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
+                :, :, :, 0
+            ]
             qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
             self.qkv.weight.data = qkv_weight
@@ -107,14 +111,19 @@ class AttentionBlock(nn.Module):
             self.proj_out = proj_out
         elif self.overwrite_linear:
-            self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+            self.qkv.weight.data = torch.concat(
+                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
+            )[:, :, None]
             self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
             self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
             self.proj_out.bias.data = self.NIN_3.b.data
+            self.norm.weight.data = self.GroupNorm_0.weight.data
+            self.norm.bias.data = self.GroupNorm_0.bias.data

     def forward(self, x, encoder_out=None):
-        if self.overwrite_qkv and not self.is_overwritten:
+        if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
             self.set_weights(self)
             self.is_overwritten = True
@@ -152,7 +161,7 @@ class AttentionBlock(nn.Module):
 # unet_score_estimation.py
-#class AttnBlockpp(nn.Module):
+# class AttnBlockpp(nn.Module):
 #     """Channel-wise self-attention block. Modified from DDPM."""
 #
 #     def __init__(
@@ -187,14 +196,11 @@ class AttentionBlock(nn.Module):
 #             self.num_heads = channels // num_head_channels
 #
 #         self.use_checkpoint = use_checkpoint
-#         self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None)
-#         self.qkv = conv_nd(1, channels, channels * 3, 1)
+#         self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+#         self.qkv = nn.Conv1d(channels, channels * 3, 1)
 #         self.n_heads = self.num_heads
 #
-#         if encoder_channels is not None:
-#             self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
-#
-#         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+#         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 #
 #         self.is_weight_set = False
 #
@@ -205,6 +211,9 @@ class AttentionBlock(nn.Module):
 #         self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
 #         self.proj_out.bias.data = self.NIN_3.b.data
 #
+#         self.norm.weight.data = self.GroupNorm_0.weight.data
+#         self.norm.bias.data = self.GroupNorm_0.bias.data
+#
 #     def forward(self, x):
 #         if not self.is_weight_set:
 #             self.set_weights()
@@ -261,6 +270,7 @@ class AttentionBlock(nn.Module):
 #
 #         return (x + h) / np.sqrt(2.0)

 # TODO(Patrick) - this can and should be removed
 def zero_module(module):
     """
...
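Aside: the set_weights path above copies NIN parameters into 1x1 convolutions by transposing the weight and adding a trailing kernel axis (self.NIN_0.W.data.T[:, :, None]). The sketch below illustrates why that porting preserves the output; it is not part of the commit, and MinimalNIN is a hypothetical stand-in that only mimics the W/b attributes the diff references (the real NIN class lives elsewhere in attention2d.py).

import torch
import torch.nn as nn


class MinimalNIN(nn.Module):
    # Hypothetical stand-in for the repo's NIN layer: a per-position linear map over channels.
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(in_dim, out_dim) * 0.1)
        self.b = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x):
        # x: (batch, channels, height, width); contract over the channel axis.
        return torch.einsum("bchw,cd->bdhw", x, self.W) + self.b[None, :, None, None]


channels = 8
nin = MinimalNIN(channels, channels)

# Port the parameters the same way the diff does: (in, out) -> (out, in, 1),
# which matches nn.Conv1d's (out_channels, in_channels, kernel_size) layout.
conv = nn.Conv1d(channels, channels, kernel_size=1)
conv.weight.data = nin.W.data.T[:, :, None]
conv.bias.data = nin.b.data

x = torch.randn(2, channels, 4, 4)
out_nin = nin(x)
out_conv = conv(x.reshape(2, channels, -1)).reshape(2, channels, 4, 4)
assert torch.allclose(out_nin, out_conv, atol=1e-5)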
@@ -30,9 +30,9 @@ from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention2d import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import AttentionBlock


 def nonlinearity(x):
@@ -219,11 +219,11 @@ class UNetModel(ModelMixin, ConfigMixin):
             for i_block in range(self.num_res_blocks):
                 h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
                     # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
                     # h = self.down[i_level].attn_2[i_block](h)
                     h = self.down[i_level].attn[i_block](h)
                     # print("Result", (h - h_2).abs().sum())
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))
...
@@ -3,9 +3,9 @@ from numpy import pad
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention2d import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import LinearAttention


 class Mish(torch.nn.Module):
...
@@ -16,18 +16,18 @@
 # helpers functions
 import functools
+import math
 import string

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import math

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .embeddings import GaussianFourierProjection, get_timestep_embedding
 from .attention2d import AttentionBlock
+from .embeddings import GaussianFourierProjection, get_timestep_embedding


 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -728,7 +728,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
             nn.init.zeros_(modules[-1].bias)

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
         Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

         if progressive == "output_skip":
...
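Aside: the AttnBlock line in the hunk above uses functools.partial to pre-bind constructor keyword arguments, so the rest of NCSNpp can instantiate attention blocks with only the per-layer arguments. A minimal sketch of the pattern, not from the commit; Block is a hypothetical stand-in for AttentionBlock.

import functools
import math


class Block:
    # Hypothetical stand-in for AttentionBlock; it only records its arguments.
    def __init__(self, channels, overwrite_linear=False, rescale_output_factor=1.0):
        self.channels = channels
        self.overwrite_linear = overwrite_linear
        self.rescale_output_factor = rescale_output_factor


# Pre-bind the keyword arguments once, as NCSNpp does for AttentionBlock.
AttnBlock = functools.partial(Block, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

# Each call now only needs the remaining positional argument.
block = AttnBlock(128)
assert block.overwrite_linear and block.rescale_output_factor == math.sqrt(2.0)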
@@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase):
         image_slice = image[0, -1, -3:, -3:].cpu()

         assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105])
+        expected_slice = torch.tensor(
+            [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
+        )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     @slow
...