Commit c482d7bd authored by Patrick von Platen

some clean up

parent 31d1f3c8
@@ -15,22 +15,12 @@
-# helpers functions
-import copy
-import math
-from pathlib import Path
 import torch
-from torch import nn
-from torch.cuda.amp import GradScaler, autocast
-from torch.optim import Adam
-from torch.utils import data
-from PIL import Image
-from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
@@ -219,11 +209,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             for i_block in range(self.num_res_blocks):
                 h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
-                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
-                    # h = self.down[i_level].attn_2[i_block](h)
                     h = self.down[i_level].attn[i_block](h)
-                    # print("Result", (h - h_2).abs().sum())
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))
......
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
......
 import torch
 from numpy import pad
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import LinearAttention
+from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
@@ -55,32 +54,6 @@ class ResnetBlock(torch.nn.Module):
         return output

-class old_LinearAttention(torch.nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super(LinearAttention, self).__init__()
-        self.heads = heads
-        self.dim_head = dim_head
-        hidden_dim = dim_head * heads
-        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-        q, k, v = (
-            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
-            .permute(1, 0, 2, 3, 4, 5)
-            .reshape(3, b, self.heads, self.dim_head, -1)
-        )
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
-        return self.to_out(out)

 class Residual(torch.nn.Module):
     def __init__(self, fn):
         super(Residual, self).__init__()
......
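For context on the hunk above: the removed `old_LinearAttention` computes linear attention with two `einsum` contractions, and its commented-out lines point at the `einops.rearrange` formulation it was ported from. Below is a minimal, self-contained sketch of the same computation written with `einops` (assuming `einops` is installed; the class name is illustrative and the code is not part of this commit):

```python
import torch
from einops import rearrange


class LinearAttentionSketch(torch.nn.Module):
    """Illustrative re-statement of the removed block using einops.rearrange."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        # split channels into q, k, v and flatten the spatial grid into n = h * w positions
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        # normalize keys over positions, then build a small (dim_head x dim_head) context per head
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        # read the context back out with the queries and restore the spatial layout
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        return self.to_out(out)
```

Because the keys are softmax-normalized over positions and contracted with the values before meeting the queries, the cost grows linearly in the number of spatial positions `h * w` rather than quadratically, which is what makes this attention "linear".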
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
......
@@ -26,7 +26,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
......