Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
af6c1439
Commit
af6c1439
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
remove einops
parent
d726857f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
38 deletions
+63
-38
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+5
-4
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+9
-9
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+13
-11
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+36
-14
No files found.
src/diffusers/models/resnet.py
View file @
af6c1439
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
...
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
def
conv_transpose_nd
(
dims
,
*
args
,
**
kwargs
):
def
conv_transpose_nd
(
dims
,
*
args
,
**
kwargs
):
"""
"""
Create a 1D, 2D, or 3D convolution module.
Create a 1D, 2D, or 3D convolution module.
...
@@ -81,15 +81,15 @@ class Upsample(nn.Module):
...
@@ -81,15 +81,15 @@ class Upsample(nn.Module):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv_transpose
:
if
self
.
use_conv_transpose
:
return
self
.
conv
(
x
)
return
self
.
conv
(
x
)
if
self
.
dims
==
3
:
if
self
.
dims
==
3
:
x
=
F
.
interpolate
(
x
,
(
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
)
x
=
F
.
interpolate
(
x
,
(
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
)
else
:
else
:
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
use_conv
:
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
return
x
return
x
...
@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
...
@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
return
x
return
x
class
GlideUpsample
(
nn
.
Module
):
class
GlideUpsample
(
nn
.
Module
):
"""
"""
An upsampling layer with an optional convolution.
An upsampling layer with an optional convolution.
...
...
src/diffusers/models/unet_grad_tts.py
View file @
af6c1439
import
torch
import
torch
try
:
from
einops
import
rearrange
except
:
print
(
"Einops is not installed"
)
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
...
@@ -81,6 +74,7 @@ class LinearAttention(torch.nn.Module):
...
@@ -81,6 +74,7 @@ class LinearAttention(torch.nn.Module):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
super
(
LinearAttention
,
self
).
__init__
()
super
(
LinearAttention
,
self
).
__init__
()
self
.
heads
=
heads
self
.
heads
=
heads
self
.
dim_head
=
dim_head
hidden_dim
=
dim_head
*
heads
hidden_dim
=
dim_head
*
heads
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
...
@@ -88,11 +82,17 @@ class LinearAttention(torch.nn.Module):
...
@@ -88,11 +82,17 @@ class LinearAttention(torch.nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
b
,
c
,
h
,
w
=
x
.
shape
b
,
c
,
h
,
w
=
x
.
shape
qkv
=
self
.
to_qkv
(
x
)
qkv
=
self
.
to_qkv
(
x
)
q
,
k
,
v
=
rearrange
(
qkv
,
"b (qkv heads c) h w -> qkv b heads c (h w)"
,
heads
=
self
.
heads
,
qkv
=
3
)
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
q
,
k
,
v
=
(
qkv
.
reshape
(
b
,
3
,
self
.
heads
,
self
.
dim_head
,
h
,
w
)
.
permute
(
1
,
0
,
2
,
3
,
4
,
5
)
.
reshape
(
3
,
b
,
self
.
heads
,
self
.
dim_head
,
-
1
)
)
k
=
k
.
softmax
(
dim
=-
1
)
k
=
k
.
softmax
(
dim
=-
1
)
context
=
torch
.
einsum
(
"bhdn,bhen->bhde"
,
k
,
v
)
context
=
torch
.
einsum
(
"bhdn,bhen->bhde"
,
k
,
v
)
out
=
torch
.
einsum
(
"bhde,bhdn->bhen"
,
context
,
q
)
out
=
torch
.
einsum
(
"bhde,bhdn->bhen"
,
context
,
q
)
out
=
rearrange
(
out
,
"b heads c (h w) -> b (heads c) h w"
,
heads
=
self
.
heads
,
h
=
h
,
w
=
w
)
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
out
=
out
.
reshape
(
b
,
self
.
heads
,
self
.
dim_head
,
h
,
w
).
reshape
(
b
,
self
.
heads
*
self
.
dim_head
,
h
,
w
)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
...
...
src/diffusers/models/unet_ldm.py
View file @
af6c1439
...
@@ -6,14 +6,15 @@ import numpy as np
...
@@ -6,14 +6,15 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
#try:
#
try:
# from einops import rearrange, repeat
# from einops import rearrange, repeat
#except:
#
except:
# print("Einops is not installed")
# print("Einops is not installed")
# pass
# pass
...
@@ -80,7 +81,7 @@ def Normalize(in_channels):
...
@@ -80,7 +81,7 @@ def Normalize(in_channels):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
#class LinearAttention(nn.Module):
#
class LinearAttention(nn.Module):
# def __init__(self, dim, heads=4, dim_head=32):
# def __init__(self, dim, heads=4, dim_head=32):
# super().__init__()
# super().__init__()
# self.heads = heads
# self.heads = heads
...
@@ -100,7 +101,7 @@ def Normalize(in_channels):
...
@@ -100,7 +101,7 @@ def Normalize(in_channels):
# return self.to_out(out)
# return self.to_out(out)
#
#
#class SpatialSelfAttention(nn.Module):
#
class SpatialSelfAttention(nn.Module):
# def __init__(self, in_channels):
# def __init__(self, in_channels):
# super().__init__()
# super().__init__()
# self.in_channels = in_channels
# self.in_channels = in_channels
...
@@ -118,7 +119,7 @@ def Normalize(in_channels):
...
@@ -118,7 +119,7 @@ def Normalize(in_channels):
# k = self.k(h_)
# k = self.k(h_)
# v = self.v(h_)
# v = self.v(h_)
#
#
# compute attention
# compute attention
# b, c, h, w = q.shape
# b, c, h, w = q.shape
# q = rearrange(q, "b c h w -> b (h w) c")
# q = rearrange(q, "b c h w -> b (h w) c")
# k = rearrange(k, "b c h w -> b c (h w)")
# k = rearrange(k, "b c h w -> b c (h w)")
...
@@ -127,7 +128,7 @@ def Normalize(in_channels):
...
@@ -127,7 +128,7 @@ def Normalize(in_channels):
# w_ = w_ * (int(c) ** (-0.5))
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
#
# attend to values
# attend to values
# v = rearrange(v, "b c h w -> b c (h w)")
# v = rearrange(v, "b c h w -> b c (h w)")
# w_ = rearrange(w_, "b i j -> b j i")
# w_ = rearrange(w_, "b i j -> b j i")
# h_ = torch.einsum("bij,bjk->bik", v, w_)
# h_ = torch.einsum("bij,bjk->bik", v, w_)
...
@@ -137,6 +138,7 @@ def Normalize(in_channels):
...
@@ -137,6 +138,7 @@ def Normalize(in_channels):
# return x + h_
# return x + h_
#
#
class
CrossAttention
(
nn
.
Module
):
class
CrossAttention
(
nn
.
Module
):
def
__init__
(
self
,
query_dim
,
context_dim
=
None
,
heads
=
8
,
dim_head
=
64
,
dropout
=
0.0
):
def
__init__
(
self
,
query_dim
,
context_dim
=
None
,
heads
=
8
,
dim_head
=
64
,
dropout
=
0.0
):
super
().
__init__
()
super
().
__init__
()
...
@@ -176,7 +178,7 @@ class CrossAttention(nn.Module):
...
@@ -176,7 +178,7 @@ class CrossAttention(nn.Module):
k
=
self
.
to_k
(
context
)
k
=
self
.
to_k
(
context
)
v
=
self
.
to_v
(
context
)
v
=
self
.
to_v
(
context
)
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
...
@@ -185,12 +187,12 @@ class CrossAttention(nn.Module):
...
@@ -185,12 +187,12 @@ class CrossAttention(nn.Module):
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
if
exists
(
mask
):
if
exists
(
mask
):
# mask = rearrange(mask, "b ... -> b (...)")
# mask = rearrange(mask, "b ... -> b (...)")
maks
=
mask
.
reshape
(
batch_size
,
-
1
)
maks
=
mask
.
reshape
(
batch_size
,
-
1
)
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
# mask = repeat(mask, "b j -> (b h) () j", h=h)
# mask = repeat(mask, "b j -> (b h) () j", h=h)
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
# attention, what we cannot get enough of
# attention, what we cannot get enough of
...
@@ -198,7 +200,7 @@ class CrossAttention(nn.Module):
...
@@ -198,7 +200,7 @@ class CrossAttention(nn.Module):
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
...
...
src/diffusers/models/unet_rl.py
View file @
af6c1439
...
@@ -5,18 +5,19 @@ import math
...
@@ -5,18 +5,19 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
try
:
import
einops
from
einops.layers.torch
import
Rearrange
except
:
print
(
"Einops is not installed"
)
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
# try:
# import einops
# from einops.layers.torch import Rearrange
# except:
# print("Einops is not installed")
# pass
class
SinusoidalPosEmb
(
nn
.
Module
):
class
SinusoidalPosEmb
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
super
().
__init__
()
...
@@ -50,6 +51,21 @@ class Upsample1d(nn.Module):
...
@@ -50,6 +51,21 @@ class Upsample1d(nn.Module):
return
self
.
conv
(
x
)
return
self
.
conv
(
x
)
class
RearrangeDim
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
def
forward
(
self
,
tensor
):
if
len
(
tensor
.
shape
)
==
2
:
return
tensor
[:,
:,
None
]
if
len
(
tensor
.
shape
)
==
3
:
return
tensor
[:,
:,
None
,
:]
elif
len
(
tensor
.
shape
)
==
4
:
return
tensor
[:,
:,
0
,
:]
else
:
raise
ValueError
(
f
"`len(tensor)`:
{
len
(
tensor
)
}
has to be 2, 3 or 4."
)
class
Conv1dBlock
(
nn
.
Module
):
class
Conv1dBlock
(
nn
.
Module
):
"""
"""
Conv1d --> GroupNorm --> Mish
Conv1d --> GroupNorm --> Mish
...
@@ -60,9 +76,11 @@ class Conv1dBlock(nn.Module):
...
@@ -60,9 +76,11 @@ class Conv1dBlock(nn.Module):
self
.
block
=
nn
.
Sequential
(
self
.
block
=
nn
.
Sequential
(
nn
.
Conv1d
(
inp_channels
,
out_channels
,
kernel_size
,
padding
=
kernel_size
//
2
),
nn
.
Conv1d
(
inp_channels
,
out_channels
,
kernel_size
,
padding
=
kernel_size
//
2
),
Rearrange
(
"batch channels horizon -> batch channels 1 horizon"
),
RearrangeDim
(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn
.
GroupNorm
(
n_groups
,
out_channels
),
nn
.
GroupNorm
(
n_groups
,
out_channels
),
Rearrange
(
"batch channels 1 horizon -> batch channels horizon"
),
RearrangeDim
(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn
.
Mish
(),
nn
.
Mish
(),
)
)
...
@@ -84,7 +102,8 @@ class ResidualTemporalBlock(nn.Module):
...
@@ -84,7 +102,8 @@ class ResidualTemporalBlock(nn.Module):
self
.
time_mlp
=
nn
.
Sequential
(
self
.
time_mlp
=
nn
.
Sequential
(
nn
.
Mish
(),
nn
.
Mish
(),
nn
.
Linear
(
embed_dim
,
out_channels
),
nn
.
Linear
(
embed_dim
,
out_channels
),
Rearrange
(
"batch t -> batch t 1"
),
RearrangeDim
(),
# Rearrange("batch t -> batch t 1"),
)
)
self
.
residual_conv
=
(
self
.
residual_conv
=
(
...
@@ -184,7 +203,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -184,7 +203,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x : [ batch x horizon x transition ]
x : [ batch x horizon x transition ]
"""
"""
x
=
einops
.
rearrange
(
x
,
"b h t -> b t h"
)
# x = einops.rearrange(x, "b h t -> b t h")
x
=
x
.
permute
(
0
,
2
,
1
)
t
=
self
.
time_mlp
(
time
)
t
=
self
.
time_mlp
(
time
)
h
=
[]
h
=
[]
...
@@ -206,7 +226,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -206,7 +226,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x
=
self
.
final_conv
(
x
)
x
=
self
.
final_conv
(
x
)
x
=
einops
.
rearrange
(
x
,
"b t h -> b h t"
)
# x = einops.rearrange(x, "b t h -> b h t")
x
=
x
.
permute
(
0
,
2
,
1
)
return
x
return
x
...
@@ -263,7 +284,8 @@ class TemporalValue(nn.Module):
...
@@ -263,7 +284,8 @@ class TemporalValue(nn.Module):
x : [ batch x horizon x transition ]
x : [ batch x horizon x transition ]
"""
"""
x
=
einops
.
rearrange
(
x
,
"b h t -> b t h"
)
# x = einops.rearrange(x, "b h t -> b t h")
x
=
x
.
permute
(
0
,
2
,
1
)
t
=
self
.
time_mlp
(
time
)
t
=
self
.
time_mlp
(
time
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment