"docs/vscode:/vscode.git/clone" did not exist on "993fd3f94baff137216d2b16dade638f3b6c99c3"
Commit af6c1439 authored by Patrick von Platen

remove einops

parent d726857f
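This commit removes the hard dependency on einops: the `try`/`except` imports are dropped or commented out, and each `rearrange` call is replaced with equivalent `reshape`/`permute` calls on plain tensors. A minimal illustration of the two translation patterns used throughout the diff (not part of the commit; shapes are arbitrary):

```python
import torch

# einops: rearrange(x, "b (heads c) h w -> b heads c (h w)", heads=heads)
# plain PyTorch: splitting dim 1 and flattening the spatial dims is a pure reshape.
b, heads, c, h, w = 2, 4, 32, 8, 8
x = torch.randn(b, heads * c, h, w)
out = x.reshape(b, heads, c, h * w)
assert out.shape == (b, heads, c, h * w)

# einops: rearrange(x, "b h t -> b t h")
# plain PyTorch: a pure axis swap is just a permute (or transpose).
y = torch.randn(2, 16, 14)
assert torch.equal(y.permute(0, 2, 1), y.transpose(1, 2))
```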
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def conv_transpose_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
        x = self.conv(x)
        return x

class GlideUpsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
...
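Only the tail of `conv_nd` is visible in the hunk above. For context, a minimal sketch of such a dims-based dispatcher, assuming the 1D and 2D branches mirror the 3D branch shown in the diff:

```python
import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D convolution module (sketch of the assumed full body)."""
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

# Example: a 2D 3x3 convolution with 64 input and 128 output channels.
conv = conv_nd(2, 64, 128, 3, padding=1)
```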
import torch

-try:
-    from einops import rearrange
-except:
-    print("Einops is not installed")
-    pass

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
@@ -81,6 +74,7 @@ class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.heads = heads
+        self.dim_head = dim_head
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
@@ -88,11 +82,17 @@ class LinearAttention(torch.nn.Module):
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+        q, k, v = (
+            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
+            .permute(1, 0, 2, 3, 4, 5)
+            .reshape(3, b, self.heads, self.dim_head, -1)
+        )
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
        return self.to_out(out)
...
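The rewritten `LinearAttention.forward` above replaces the two `rearrange` calls with `reshape`/`permute` chains. A quick standalone check of the q/k/v split, with an optional einops cross-check guarded by an import (illustrative shapes, using the default `heads=4`, `dim_head=32`; not part of the commit):

```python
import torch

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
qkv = torch.randn(b, 3 * heads * dim_head, h, w)

# New einops-free path from the diff above.
q, k, v = (
    qkv.reshape(b, 3, heads, dim_head, h, w)
    .permute(1, 0, 2, 3, 4, 5)
    .reshape(3, b, heads, dim_head, -1)
)

try:
    from einops import rearrange  # optional: only to cross-check against the old path
    q_ref, k_ref, v_ref = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)
    torch.testing.assert_close(q, q_ref)
    torch.testing.assert_close(k, k_ref)
    torch.testing.assert_close(v, v_ref)
except ImportError:
    pass
```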
@@ -6,14 +6,15 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding

-#try:
+# try:
#     from einops import rearrange, repeat
-#except:
+# except:
#     print("Einops is not installed")
#     pass
@@ -80,7 +81,7 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

-#class LinearAttention(nn.Module):
+# class LinearAttention(nn.Module):
#     def __init__(self, dim, heads=4, dim_head=32):
#         super().__init__()
#         self.heads = heads
@@ -100,7 +101,7 @@ def Normalize(in_channels):
#         return self.to_out(out)
#
-#class SpatialSelfAttention(nn.Module):
+# class SpatialSelfAttention(nn.Module):
#     def __init__(self, in_channels):
#         super().__init__()
#         self.in_channels = in_channels
@@ -118,7 +119,7 @@ def Normalize(in_channels):
#         k = self.k(h_)
#         v = self.v(h_)
#
#         # compute attention
#         b, c, h, w = q.shape
#         q = rearrange(q, "b c h w -> b (h w) c")
#         k = rearrange(k, "b c h w -> b c (h w)")
@@ -127,7 +128,7 @@ def Normalize(in_channels):
#         w_ = w_ * (int(c) ** (-0.5))
#         w_ = torch.nn.functional.softmax(w_, dim=2)
#
#         # attend to values
#         v = rearrange(v, "b c h w -> b c (h w)")
#         w_ = rearrange(w_, "b i j -> b j i")
#         h_ = torch.einsum("bij,bjk->bik", v, w_)
@@ -137,6 +138,7 @@ def Normalize(in_channels):
#         return x + h_
#

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
@@ -176,7 +178,7 @@ class CrossAttention(nn.Module):
        k = self.to_k(context)
        v = self.to_v(context)

        # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        q = self.reshape_heads_to_batch_dim(q)
        k = self.reshape_heads_to_batch_dim(k)
@@ -185,12 +187,12 @@ class CrossAttention(nn.Module):
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            # mask = rearrange(mask, "b ... -> b (...)")
            maks = mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            # mask = repeat(mask, "b j -> (b h) () j", h=h)
            mask = mask[:, None, :].repeat(h, 1, 1)
            # x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
@@ -198,7 +200,7 @@ class CrossAttention(nn.Module):
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = self.reshape_batch_dim_to_heads(out)
        # out = rearrange(out, "(b h) n d -> b n (h d)", h=h)

        return self.to_out(out)
...
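In the `CrossAttention` mask handling above, note that the flattened mask is assigned to `maks`, which looks like a typo (the reshaped tensor is never used; this only matters if the incoming mask has more than two dimensions). The sketch below exercises the intended flatten-then-broadcast sequence in isolation with plain PyTorch; the shapes are illustrative and not taken from the commit:

```python
import torch

batch_size, h, query_len, key_len = 2, 8, 4, 6        # h: number of attention heads
sim = torch.randn(batch_size * h, query_len, key_len)  # per-head attention scores
mask = torch.rand(batch_size, key_len) > 0.5            # True = attend, False = mask out

mask = mask.reshape(batch_size, -1)              # was: rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = mask[:, None, :].repeat(h, 1, 1)          # was: repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)           # masked keys get a large negative score
print(sim.shape)                                 # torch.Size([16, 4, 6])
```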
@@ -5,18 +5,19 @@ import math
import torch
import torch.nn as nn

-try:
-    import einops
-    from einops.layers.torch import Rearrange
-except:
-    print("Einops is not installed")
-    pass

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin

+# try:
+#     import einops
+#     from einops.layers.torch import Rearrange
+# except:
+#     print("Einops is not installed")
+#     pass

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
@@ -50,6 +51,21 @@ class Upsample1d(nn.Module):
        return self.conv(x)

+class RearrangeDim(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, tensor):
+        if len(tensor.shape) == 2:
+            return tensor[:, :, None]
+        if len(tensor.shape) == 3:
+            return tensor[:, :, None, :]
+        elif len(tensor.shape) == 4:
+            return tensor[:, :, 0, :]
+        else:
+            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")

class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
@@ -60,9 +76,11 @@ class Conv1dBlock(nn.Module):
        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            Rearrange("batch channels horizon -> batch channels 1 horizon"),
+            RearrangeDim(),
+            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
-            Rearrange("batch channels 1 horizon -> batch channels horizon"),
+            RearrangeDim(),
+            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )
@@ -84,7 +102,8 @@ class ResidualTemporalBlock(nn.Module):
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
-            Rearrange("batch t -> batch t 1"),
+            RearrangeDim(),
+            # Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
@@ -184,7 +203,8 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
            x : [ batch x horizon x transition ]
        """

-        x = einops.rearrange(x, "b h t -> b t h")
+        # x = einops.rearrange(x, "b h t -> b t h")
+        x = x.permute(0, 2, 1)

        t = self.time_mlp(time)
        h = []
@@ -206,7 +226,8 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
        x = self.final_conv(x)

-        x = einops.rearrange(x, "b t h -> b h t")
+        # x = einops.rearrange(x, "b t h -> b h t")
+        x = x.permute(0, 2, 1)

        return x
@@ -263,7 +284,8 @@ class TemporalValue(nn.Module):
            x : [ batch x horizon x transition ]
        """

-        x = einops.rearrange(x, "b h t -> b t h")
+        # x = einops.rearrange(x, "b h t -> b t h")
+        x = x.permute(0, 2, 1)

        t = self.time_mlp(time)
...
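The new `RearrangeDim` module stands in for the pattern-based `einops.layers.torch.Rearrange` layers by dispatching on input rank: a 3D `(batch, channels, horizon)` tensor gains a singleton dimension before `GroupNorm`, and the 4D result drops it again. A small standalone check (illustrative channel counts, not part of the commit):

```python
import torch
import torch.nn as nn

class RearrangeDim(nn.Module):
    # Same logic as the module added in the diff above.
    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]        # "batch t -> batch t 1"
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]     # "batch channels horizon -> batch channels 1 horizon"
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]        # "batch channels 1 horizon -> batch channels horizon"
        else:
            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")

block = nn.Sequential(
    nn.Conv1d(14, 32, 5, padding=2),
    RearrangeDim(),                  # (batch, 32, horizon) -> (batch, 32, 1, horizon)
    nn.GroupNorm(8, 32),
    RearrangeDim(),                  # (batch, 32, 1, horizon) -> (batch, 32, horizon)
    nn.Mish(),
)
x = torch.randn(2, 14, 16)           # (batch, transition, horizon) laid out as Conv1d input
print(block(x).shape)                # torch.Size([2, 32, 16])
```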