Commit d726857f authored by Patrick von Platen

remove einops from unet_ldm

parent c991ffd4
@@ -6,19 +6,18 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-try:
-    from einops import rearrange, repeat
-except:
-    print("Einops is not installed")
-    pass
-
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding

+# try:
+#     from einops import rearrange, repeat
+# except:
+#     print("Einops is not installed")
+#     pass


 def exists(val):
     return val is not None
@@ -153,7 +152,23 @@ class CrossAttention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
     def forward(self, x, context=None, mask=None):
+        batch_size, sequence_length, dim = x.shape
         h = self.heads

         q = self.to_q(x)
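The two helpers added above mirror the einops patterns being retired. A minimal, self-contained sanity check (shapes and variable names are illustrative, and einops is assumed to be available only for the comparison):

# Sanity check: the new reshape helpers reproduce the einops patterns they replace.
import torch
from einops import rearrange

b, n, heads, d_head = 2, 5, 4, 8
t = torch.randn(b, n, heads * d_head)

# reshape_heads_to_batch_dim: "b n (h d) -> (b h) n d"
split = t.reshape(b, n, heads, d_head).permute(0, 2, 1, 3).reshape(b * heads, n, d_head)
assert torch.equal(split, rearrange(t, "b n (h d) -> (b h) n d", h=heads))

# reshape_batch_dim_to_heads: "(b h) n d -> b n (h d)"
merged = split.reshape(b, heads, n, d_head).permute(0, 2, 1, 3).reshape(b, n, heads * d_head)
assert torch.equal(merged, rearrange(split, "(b h) n d -> b n (h d)", h=heads))
assert torch.equal(merged, t)  # round-trip recovers the original tensor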
@@ -161,21 +176,29 @@ class CrossAttention(nn.Module):
         k = self.to_k(context)
         v = self.to_v(context)

-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        q = self.reshape_heads_to_batch_dim(q)
+        k = self.reshape_heads_to_batch_dim(k)
+        v = self.reshape_heads_to_batch_dim(v)

         sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

         if exists(mask):
-            mask = rearrange(mask, "b ... -> b (...)")
+            # mask = rearrange(mask, "b ... -> b (...)")
+            mask = mask.reshape(batch_size, -1)
             max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
+            # mask = repeat(mask, "b j -> (b h) () j", h=h)
+            # repeat_interleave keeps the batch index as the outer factor, matching reshape_heads_to_batch_dim
+            mask = mask[:, None, :].repeat_interleave(h, dim=0)
             sim.masked_fill_(~mask, max_neg_value)

         # attention, what we cannot get enough of
         attn = sim.softmax(dim=-1)

         out = torch.einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        out = self.reshape_batch_dim_to_heads(out)
+        # out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
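The mask path can be checked the same way. The sketch below (illustrative shapes, einops assumed only for the comparison) confirms that repeat_interleave reproduces the retired repeat pattern, keeping the batch index as the outer factor; a plain .repeat(h, 1, 1) would tile the whole batch instead and only agrees when the batch size is 1.

# Sanity check: repeat_interleave matches repeat(mask, "b j -> (b h) () j", h=h).
import torch
from einops import repeat

b, j, heads = 2, 5, 4
mask = torch.rand(b, j) > 0.5

expanded = mask[:, None, :].repeat_interleave(heads, dim=0)  # (b*heads, 1, j), batch-major ordering
assert torch.equal(expanded, repeat(mask, "b j -> (b h) () j", h=heads))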
@@ -233,10 +256,10 @@ class SpatialTransformer(nn.Module):
         x_in = x
         x = self.norm(x)
         x = self.proj_in(x)
-        x = rearrange(x, "b c h w -> b (h w) c")
+        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
         for block in self.transformer_blocks:
             x = block(x, context=context)
-        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
         x = self.proj_out(x)
         return x + x_in
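Likewise for the SpatialTransformer block: the permute/reshape pair used above round-trips between (b, c, h, w) feature maps and (b, h*w, c) token sequences just as the removed rearrange calls did. A small check under the same assumptions (illustrative shapes, einops only for comparison):

# Sanity check: permute/reshape matches the removed rearrange calls.
import torch
from einops import rearrange

b, c, h, w = 2, 16, 8, 8
x = torch.randn(b, c, h, w)

tokens = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
assert torch.equal(tokens, rearrange(x, "b c h w -> b (h w) c"))

restored = tokens.reshape(b, h, w, c).permute(0, 3, 1, 2)
assert torch.equal(restored, rearrange(tokens, "b (h w) c -> b c h w", h=h, w=w))
assert torch.equal(restored, x)  # round-trip recovers the original feature map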