Commit af6c1439 authored by Patrick von Platen

remove einops

parent d726857f
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D transposed convolution module.
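The body of conv_transpose_nd is elided in this hunk; mirroring conv_nd above, it presumably dispatches to the matching transposed convolution. A sketch under that assumption, not the verbatim source:

def conv_transpose_nd(dims, *args, **kwargs):
    # hypothetical dispatch mirroring conv_nd above
    if dims == 1:
        return nn.ConvTranspose1d(*args, **kwargs)
    elif dims == 2:
        return nn.ConvTranspose2d(*args, **kwargs)
    elif dims == 3:
        return nn.ConvTranspose3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")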
......@@ -81,15 +81,15 @@ class Upsample(nn.Module):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
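A quick shape check of the two interpolation branches above (a minimal sketch on dummy tensors, not part of the commit): for 2D feature maps both spatial axes are doubled, while for 3D inputs the depth axis is explicitly kept fixed and only height and width are doubled.

import torch
import torch.nn.functional as F

x2d = torch.randn(1, 8, 16, 16)
y2d = F.interpolate(x2d, scale_factor=2.0, mode="nearest")
assert y2d.shape == (1, 8, 32, 32)

x3d = torch.randn(1, 8, 4, 16, 16)  # (batch, channels, depth, height, width)
y3d = F.interpolate(x3d, (x3d.shape[2], x3d.shape[3] * 2, x3d.shape[4] * 2), mode="nearest")
assert y3d.shape == (1, 8, 4, 32, 32)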
......@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
x = self.conv(x)
return x
class GlideUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
......
import torch
try:
from einops import rearrange
except ImportError:
    print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
......@@ -81,6 +74,7 @@ class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
self.heads = heads
self.dim_head = dim_head
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
......@@ -88,11 +82,17 @@ class LinearAttention(torch.nn.Module):
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
q, k, v = (
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
.permute(1, 0, 2, 3, 4, 5)
.reshape(3, b, self.heads, self.dim_head, -1)
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
return self.to_out(out)
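The new reshape/permute chains can be checked against the rearrange patterns they replace on a dummy tensor (a minimal sketch, assuming einops is still installed for the comparison only):

import torch
from einops import rearrange  # only used here to verify the equivalence

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
qkv = torch.randn(b, 3 * heads * dim_head, h, w)

ref = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)
new = (
    qkv.reshape(b, 3, heads, dim_head, h, w)
    .permute(1, 0, 2, 3, 4, 5)
    .reshape(3, b, heads, dim_head, -1)
)
assert torch.equal(ref, new)

out = torch.randn(b, heads, dim_head, h * w)
ref_out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=heads, h=h, w=w)
new_out = out.reshape(b, heads, dim_head, h, w).reshape(b, heads * dim_head, h, w)
assert torch.equal(ref_out, new_out)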
......
......@@ -6,14 +6,15 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
#try:
# try:
# from einops import rearrange, repeat
#except:
# except:
# print("Einops is not installed")
# pass
......@@ -80,7 +81,7 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
#class LinearAttention(nn.Module):
# class LinearAttention(nn.Module):
# def __init__(self, dim, heads=4, dim_head=32):
# super().__init__()
# self.heads = heads
......@@ -100,7 +101,7 @@ def Normalize(in_channels):
# return self.to_out(out)
#
#class SpatialSelfAttention(nn.Module):
# class SpatialSelfAttention(nn.Module):
# def __init__(self, in_channels):
# super().__init__()
# self.in_channels = in_channels
......@@ -118,7 +119,7 @@ def Normalize(in_channels):
# k = self.k(h_)
# v = self.v(h_)
#
# compute attention
# compute attention
# b, c, h, w = q.shape
# q = rearrange(q, "b c h w -> b (h w) c")
# k = rearrange(k, "b c h w -> b c (h w)")
......@@ -127,7 +128,7 @@ def Normalize(in_channels):
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values
# attend to values
# v = rearrange(v, "b c h w -> b c (h w)")
# w_ = rearrange(w_, "b i j -> b j i")
# h_ = torch.einsum("bij,bjk->bik", v, w_)
......@@ -137,6 +138,7 @@ def Normalize(in_channels):
# return x + h_
#
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
......@@ -176,7 +178,7 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
......@@ -185,12 +187,12 @@ class CrossAttention(nn.Module):
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
# mask = rearrange(mask, "b ... -> b (...)")
# mask = rearrange(mask, "b ... -> b (...)")
mask = mask.reshape(batch_size, -1)
max_neg_value = -torch.finfo(sim.dtype).max
# mask = repeat(mask, "b j -> (b h) () j", h=h)
# mask = repeat(mask, "b j -> (b h) () j", h=h)
mask = mask[:, None, :].repeat(h, 1, 1)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
......@@ -198,7 +200,7 @@ class CrossAttention(nn.Module):
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
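The bodies of reshape_heads_to_batch_dim and reshape_batch_dim_to_heads are not shown in this hunk; helpers consistent with the replaced patterns rearrange("b n (h d) -> (b h) n d", h=h) and its inverse would look roughly like the following sketch (hypothetical free-function versions, not the verbatim methods):

import torch

def reshape_heads_to_batch_dim(tensor, heads):
    # rearrange("b n (h d) -> (b h) n d", h=heads), batch-major over (b, h)
    b, n, dim = tensor.shape
    tensor = tensor.reshape(b, n, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(b * heads, n, dim // heads)

def reshape_batch_dim_to_heads(tensor, heads):
    # inverse: rearrange("(b h) n d -> b n (h d)", h=heads)
    bh, n, d = tensor.shape
    b = bh // heads
    tensor = tensor.reshape(b, heads, n, d)
    return tensor.permute(0, 2, 1, 3).reshape(b, n, heads * d)

x = torch.randn(2, 5, 8 * 64)
assert torch.equal(reshape_batch_dim_to_heads(reshape_heads_to_batch_dim(x, 8), 8), x)

One subtlety in the mask path above: mask[:, None, :].repeat(h, 1, 1) tiles the batch axis head-major ((h b) order), whereas the replaced repeat(mask, "b j -> (b h) () j", h=h) is batch-major; if queries and keys are laid out batch-major as in this sketch, mask.repeat_interleave(h, dim=0)[:, None, :] would be the exact equivalent.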
......
......@@ -5,18 +5,19 @@ import math
import torch
import torch.nn as nn
try:
import einops
from einops.layers.torch import Rearrange
except:
print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
# try:
# import einops
# from einops.layers.torch import Rearrange
# except:
# print("Einops is not installed")
# pass
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
......@@ -50,6 +51,21 @@ class Upsample1d(nn.Module):
return self.conv(x)
class RearrangeDim(nn.Module):
def __init__(self):
super().__init__()
def forward(self, tensor):
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
return tensor[:, :, None, :]
elif len(tensor.shape) == 4:
return tensor[:, :, 0, :]
else:
raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
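A quick check of how RearrangeDim stands in for the einops Rearrange layers used in the blocks below (a minimal sketch on dummy tensors): a 3D activation gains a singleton axis before GroupNorm and loses it again afterwards, and a 2D time embedding gains a trailing axis, matching Rearrange("batch t -> batch t 1").

import torch

rearrange_dim = RearrangeDim()

x = torch.randn(4, 16, 32)            # (batch, channels, horizon) out of nn.Conv1d
y = rearrange_dim(x)                  # -> (batch, channels, 1, horizon) for GroupNorm
assert y.shape == (4, 16, 1, 32)
z = rearrange_dim(y)                  # 4D input: drop the singleton axis again
assert torch.equal(z, x)

t = torch.randn(4, 16)                # (batch, t), e.g. the time embedding
assert rearrange_dim(t).shape == (4, 16, 1)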
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
......@@ -60,9 +76,11 @@ class Conv1dBlock(nn.Module):
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
Rearrange("batch channels horizon -> batch channels 1 horizon"),
RearrangeDim(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
Rearrange("batch channels 1 horizon -> batch channels horizon"),
RearrangeDim(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)
......@@ -84,7 +102,8 @@ class ResidualTemporalBlock(nn.Module):
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
Rearrange("batch t -> batch t 1"),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
......@@ -184,7 +203,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x : [ batch x horizon x transition ]
"""
x = einops.rearrange(x, "b h t -> b t h")
# x = einops.rearrange(x, "b h t -> b t h")
x = x.permute(0, 2, 1)
t = self.time_mlp(time)
h = []
......@@ -206,7 +226,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x = self.final_conv(x)
x = einops.rearrange(x, "b t h -> b h t")
# x = einops.rearrange(x, "b t h -> b h t")
x = x.permute(0, 2, 1)
return x
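For the pure axis swaps above, x.permute(0, 2, 1) is a drop-in replacement for einops.rearrange(x, "b h t -> b t h") and for its inverse, which is the same permutation; the only behavioural difference is that permute returns a non-contiguous view, which the downstream convolution layers accept. A minimal check on a dummy tensor:

import torch

x = torch.randn(2, 32, 14)                 # (batch, horizon, transition)
y = x.permute(0, 2, 1)                     # rearrange(x, "b h t -> b t h")
assert y.shape == (2, 14, 32)
assert not y.is_contiguous()               # a view, no copy is made
assert torch.equal(y.permute(0, 2, 1), x)  # the inverse is the same permutation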
......@@ -263,7 +284,8 @@ class TemporalValue(nn.Module):
x : [ batch x horizon x transition ]
"""
x = einops.rearrange(x, "b h t -> b t h")
# x = einops.rearrange(x, "b h t -> b t h")
x = x.permute(0, 2, 1)
t = self.time_mlp(time)
......