"examples/vscode:/vscode.git/clone" did not exist on "eadf0e2555cfa19b033e02de53553f71ac33536f"
Commit af6c1439 authored by Patrick von Platen

remove einops

parent d726857f
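The commit swaps every einops call in these files for native torch tensor ops. For reference (not part of the diff), the einops patterns touched below map to torch roughly as follows; the tensor names and sizes are made up for illustration:

import torch

# rearrange(x, "b h t -> b t h")  ->  x.permute(0, 2, 1)
x = torch.randn(2, 5, 7)
assert x.permute(0, 2, 1).shape == (2, 7, 5)

# Rearrange("batch channels horizon -> batch channels 1 horizon")  ->  insert a singleton dim
y = torch.randn(2, 8, 16)
assert y[:, :, None, :].shape == (2, 8, 1, 16)

# repeat(mask, "b j -> (b h) () j", h=h)  ->  mask[:, None, :].repeat_interleave(h, dim=0)
mask = torch.ones(2, 4, dtype=torch.bool)
assert mask[:, None, :].repeat_interleave(3, dim=0).shape == (6, 1, 4)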
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D transposed convolution module.
@@ -81,15 +81,15 @@ class Upsample(nn.Module):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
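# For reference (not part of the diff): shape behaviour of the two interpolate branches above.
import torch
import torch.nn.functional as F

x5 = torch.randn(1, 4, 3, 8, 8)      # (batch, channels, frames, height, width)
out = F.interpolate(x5, (x5.shape[2], x5.shape[3] * 2, x5.shape[4] * 2), mode="nearest")
assert out.shape == (1, 4, 3, 16, 16)   # frame dim kept, H and W doubled
x4 = torch.randn(1, 4, 8, 8)
assert F.interpolate(x4, scale_factor=2.0, mode="nearest").shape == (1, 4, 16, 16)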
@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
x = self.conv(x)
return x
class GlideUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
......
import torch
try:
from einops import rearrange
except:
print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
@@ -81,6 +74,7 @@ class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
self.heads = heads
self.dim_head = dim_head
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
@@ -88,11 +82,17 @@ class LinearAttention(torch.nn.Module):
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
q, k, v = (
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
.permute(1, 0, 2, 3, 4, 5)
.reshape(3, b, self.heads, self.dim_head, -1)
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
return self.to_out(out)
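# Sanity check (not from the commit): the reshape/permute chain above reproduces the
# removed einops pattern exactly; einops is imported here only for the comparison.
import torch
from einops import rearrange

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
qkv = torch.randn(b, 3 * heads * dim_head, h, w)
ref = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)
new = (
    qkv.reshape(b, 3, heads, dim_head, h, w)
    .permute(1, 0, 2, 3, 4, 5)
    .reshape(3, b, heads, dim_head, -1)
)
assert torch.equal(ref, new)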
......
@@ -6,14 +6,15 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
#try:
# try:
# from einops import rearrange, repeat
#except:
# except:
# print("Einops is not installed")
# pass
@@ -80,7 +81,7 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
#class LinearAttention(nn.Module):
# class LinearAttention(nn.Module):
# def __init__(self, dim, heads=4, dim_head=32):
# super().__init__()
# self.heads = heads
@@ -100,7 +101,7 @@ def Normalize(in_channels):
# return self.to_out(out)
#
#class SpatialSelfAttention(nn.Module):
# class SpatialSelfAttention(nn.Module):
# def __init__(self, in_channels):
# super().__init__()
# self.in_channels = in_channels
@@ -118,7 +119,7 @@ def Normalize(in_channels):
# k = self.k(h_)
# v = self.v(h_)
#
# compute attention
# compute attention
# b, c, h, w = q.shape
# q = rearrange(q, "b c h w -> b (h w) c")
# k = rearrange(k, "b c h w -> b c (h w)")
@@ -127,7 +128,7 @@ def Normalize(in_channels):
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values
# attend to values
# v = rearrange(v, "b c h w -> b c (h w)")
# w_ = rearrange(w_, "b i j -> b j i")
# h_ = torch.einsum("bij,bjk->bik", v, w_)
@@ -137,6 +138,7 @@ def Normalize(in_channels):
# return x + h_
#
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
@@ -176,7 +178,7 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
@@ -185,12 +187,12 @@ class CrossAttention(nn.Module):
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
# mask = rearrange(mask, "b ... -> b (...)")
# mask = rearrange(mask, "b ... -> b (...)")
mask = mask.reshape(batch_size, -1)
max_neg_value = -torch.finfo(sim.dtype).max
# mask = repeat(mask, "b j -> (b h) () j", h=h)
# mask = repeat(mask, "b j -> (b h) () j", h=h)
mask = mask[:, None, :].repeat_interleave(h, dim=0)  # matches the (b h) grouping of the einops repeat above
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
@@ -198,7 +200,7 @@ class CrossAttention(nn.Module):
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
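# Note (not from the commit): the removed einops call repeat(mask, "b j -> (b h) () j", h=h)
# groups the head copies inside each batch element, which is repeat_interleave along dim 0,
# not a plain tile. A tiny illustration:
import torch

m = torch.tensor([[True, False], [False, True]])     # (b=2, j=2)
grouped = m[:, None, :].repeat_interleave(2, dim=0)   # m0, m0, m1, m1  -- matches (b h)
tiled = m[:, None, :].repeat(2, 1, 1)                 # m0, m1, m0, m1  -- head-major tiling
assert grouped.shape == tiled.shape == (4, 1, 2)
assert not torch.equal(grouped, tiled)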
......
@@ -5,18 +5,19 @@ import math
import torch
import torch.nn as nn
try:
import einops
from einops.layers.torch import Rearrange
except:
print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
# try:
# import einops
# from einops.layers.torch import Rearrange
# except:
# print("Einops is not installed")
# pass
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -50,6 +51,21 @@ class Upsample1d(nn.Module):
return self.conv(x)
class RearrangeDim(nn.Module):
def __init__(self):
super().__init__()
def forward(self, tensor):
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
return tensor[:, :, None, :]
elif len(tensor.shape) == 4:
return tensor[:, :, 0, :]
else:
raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
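# Quick shape check (not in the commit), assuming the RearrangeDim class above is in scope.
# It stands in for the three removed Rearrange layers used further down:
import torch

r = RearrangeDim()
t = torch.randn(2, 8)                # (batch, t)
assert r(t).shape == (2, 8, 1)       # Rearrange("batch t -> batch t 1")
x = torch.randn(2, 8, 16)            # (batch, channels, horizon)
assert r(x).shape == (2, 8, 1, 16)   # Rearrange("batch channels horizon -> batch channels 1 horizon")
assert r(r(x)).shape == (2, 8, 16)   # Rearrange("batch channels 1 horizon -> batch channels horizon")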
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
@@ -60,9 +76,11 @@ class Conv1dBlock(nn.Module):
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
Rearrange("batch channels horizon -> batch channels 1 horizon"),
RearrangeDim(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
Rearrange("batch channels 1 horizon -> batch channels horizon"),
RearrangeDim(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)
@@ -84,7 +102,8 @@ class ResidualTemporalBlock(nn.Module):
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
Rearrange("batch t -> batch t 1"),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
@@ -184,7 +203,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x : [ batch x horizon x transition ]
"""
x = einops.rearrange(x, "b h t -> b t h")
# x = einops.rearrange(x, "b h t -> b t h")
x = x.permute(0, 2, 1)
t = self.time_mlp(time)
h = []
@@ -206,7 +226,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x = self.final_conv(x)
x = einops.rearrange(x, "b t h -> b h t")
# x = einops.rearrange(x, "b t h -> b h t")
x = x.permute(0, 2, 1)
return x
@@ -263,7 +284,8 @@ class TemporalValue(nn.Module):
x : [ batch x horizon x transition ]
"""
x = einops.rearrange(x, "b h t -> b t h")
# x = einops.rearrange(x, "b h t -> b t h")
x = x.permute(0, 2, 1)
t = self.time_mlp(time)
......