Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
af6c1439
"vscode:/vscode.git/clone" did not exist on "e3a2c7f02cd9e47a9093efe3e9659c8e99e28aac"
Commit
af6c1439
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
remove einops
parent
d726857f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
38 deletions
+63
-38
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+5
-4
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+9
-9
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+13
-11
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+36
-14
No files found.
src/diffusers/models/resnet.py
View file @
af6c1439
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
...
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
def
conv_transpose_nd
(
dims
,
*
args
,
**
kwargs
):
def
conv_transpose_nd
(
dims
,
*
args
,
**
kwargs
):
"""
"""
Create a 1D, 2D, or 3D convolution module.
Create a 1D, 2D, or 3D convolution module.
...
@@ -81,15 +81,15 @@ class Upsample(nn.Module):
...
@@ -81,15 +81,15 @@ class Upsample(nn.Module):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv_transpose
:
if
self
.
use_conv_transpose
:
return
self
.
conv
(
x
)
return
self
.
conv
(
x
)
if
self
.
dims
==
3
:
if
self
.
dims
==
3
:
x
=
F
.
interpolate
(
x
,
(
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
)
x
=
F
.
interpolate
(
x
,
(
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
)
else
:
else
:
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
use_conv
:
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
return
x
return
x
...
@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
...
@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
return
x
return
x
class
GlideUpsample
(
nn
.
Module
):
class
GlideUpsample
(
nn
.
Module
):
"""
"""
An upsampling layer with an optional convolution.
An upsampling layer with an optional convolution.
...
...
src/diffusers/models/unet_grad_tts.py
View file @
af6c1439
import
torch
import
torch
try
:
from
einops
import
rearrange
except
:
print
(
"Einops is not installed"
)
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
...
@@ -81,6 +74,7 @@ class LinearAttention(torch.nn.Module):
...
@@ -81,6 +74,7 @@ class LinearAttention(torch.nn.Module):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
super
(
LinearAttention
,
self
).
__init__
()
super
(
LinearAttention
,
self
).
__init__
()
self
.
heads
=
heads
self
.
heads
=
heads
self
.
dim_head
=
dim_head
hidden_dim
=
dim_head
*
heads
hidden_dim
=
dim_head
*
heads
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_qkv
=
torch
.
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
self
.
to_out
=
torch
.
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
...
@@ -88,11 +82,17 @@ class LinearAttention(torch.nn.Module):
...
@@ -88,11 +82,17 @@ class LinearAttention(torch.nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
b
,
c
,
h
,
w
=
x
.
shape
b
,
c
,
h
,
w
=
x
.
shape
qkv
=
self
.
to_qkv
(
x
)
qkv
=
self
.
to_qkv
(
x
)
q
,
k
,
v
=
rearrange
(
qkv
,
"b (qkv heads c) h w -> qkv b heads c (h w)"
,
heads
=
self
.
heads
,
qkv
=
3
)
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
q
,
k
,
v
=
(
qkv
.
reshape
(
b
,
3
,
self
.
heads
,
self
.
dim_head
,
h
,
w
)
.
permute
(
1
,
0
,
2
,
3
,
4
,
5
)
.
reshape
(
3
,
b
,
self
.
heads
,
self
.
dim_head
,
-
1
)
)
k
=
k
.
softmax
(
dim
=-
1
)
k
=
k
.
softmax
(
dim
=-
1
)
context
=
torch
.
einsum
(
"bhdn,bhen->bhde"
,
k
,
v
)
context
=
torch
.
einsum
(
"bhdn,bhen->bhde"
,
k
,
v
)
out
=
torch
.
einsum
(
"bhde,bhdn->bhen"
,
context
,
q
)
out
=
torch
.
einsum
(
"bhde,bhdn->bhen"
,
context
,
q
)
out
=
rearrange
(
out
,
"b heads c (h w) -> b (heads c) h w"
,
heads
=
self
.
heads
,
h
=
h
,
w
=
w
)
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
out
=
out
.
reshape
(
b
,
self
.
heads
,
self
.
dim_head
,
h
,
w
).
reshape
(
b
,
self
.
heads
*
self
.
dim_head
,
h
,
w
)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
...
...
src/diffusers/models/unet_ldm.py
View file @
af6c1439
...
@@ -6,14 +6,15 @@ import numpy as np
...
@@ -6,14 +6,15 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
#try:
#
try:
# from einops import rearrange, repeat
# from einops import rearrange, repeat
#except:
#
except:
# print("Einops is not installed")
# print("Einops is not installed")
# pass
# pass
...
@@ -80,7 +81,7 @@ def Normalize(in_channels):
...
@@ -80,7 +81,7 @@ def Normalize(in_channels):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
#class LinearAttention(nn.Module):
#
class LinearAttention(nn.Module):
# def __init__(self, dim, heads=4, dim_head=32):
# def __init__(self, dim, heads=4, dim_head=32):
# super().__init__()
# super().__init__()
# self.heads = heads
# self.heads = heads
...
@@ -100,7 +101,7 @@ def Normalize(in_channels):
...
@@ -100,7 +101,7 @@ def Normalize(in_channels):
# return self.to_out(out)
# return self.to_out(out)
#
#
#class SpatialSelfAttention(nn.Module):
#
class SpatialSelfAttention(nn.Module):
# def __init__(self, in_channels):
# def __init__(self, in_channels):
# super().__init__()
# super().__init__()
# self.in_channels = in_channels
# self.in_channels = in_channels
...
@@ -118,7 +119,7 @@ def Normalize(in_channels):
...
@@ -118,7 +119,7 @@ def Normalize(in_channels):
# k = self.k(h_)
# k = self.k(h_)
# v = self.v(h_)
# v = self.v(h_)
#
#
# compute attention
# compute attention
# b, c, h, w = q.shape
# b, c, h, w = q.shape
# q = rearrange(q, "b c h w -> b (h w) c")
# q = rearrange(q, "b c h w -> b (h w) c")
# k = rearrange(k, "b c h w -> b c (h w)")
# k = rearrange(k, "b c h w -> b c (h w)")
...
@@ -127,7 +128,7 @@ def Normalize(in_channels):
...
@@ -127,7 +128,7 @@ def Normalize(in_channels):
# w_ = w_ * (int(c) ** (-0.5))
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
#
# attend to values
# attend to values
# v = rearrange(v, "b c h w -> b c (h w)")
# v = rearrange(v, "b c h w -> b c (h w)")
# w_ = rearrange(w_, "b i j -> b j i")
# w_ = rearrange(w_, "b i j -> b j i")
# h_ = torch.einsum("bij,bjk->bik", v, w_)
# h_ = torch.einsum("bij,bjk->bik", v, w_)
...
@@ -137,6 +138,7 @@ def Normalize(in_channels):
...
@@ -137,6 +138,7 @@ def Normalize(in_channels):
# return x + h_
# return x + h_
#
#
class
CrossAttention
(
nn
.
Module
):
class
CrossAttention
(
nn
.
Module
):
def
__init__
(
self
,
query_dim
,
context_dim
=
None
,
heads
=
8
,
dim_head
=
64
,
dropout
=
0.0
):
def
__init__
(
self
,
query_dim
,
context_dim
=
None
,
heads
=
8
,
dim_head
=
64
,
dropout
=
0.0
):
super
().
__init__
()
super
().
__init__
()
...
@@ -176,7 +178,7 @@ class CrossAttention(nn.Module):
...
@@ -176,7 +178,7 @@ class CrossAttention(nn.Module):
k
=
self
.
to_k
(
context
)
k
=
self
.
to_k
(
context
)
v
=
self
.
to_v
(
context
)
v
=
self
.
to_v
(
context
)
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
...
@@ -185,12 +187,12 @@ class CrossAttention(nn.Module):
...
@@ -185,12 +187,12 @@ class CrossAttention(nn.Module):
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
if
exists
(
mask
):
if
exists
(
mask
):
# mask = rearrange(mask, "b ... -> b (...)")
# mask = rearrange(mask, "b ... -> b (...)")
maks
=
mask
.
reshape
(
batch_size
,
-
1
)
maks
=
mask
.
reshape
(
batch_size
,
-
1
)
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
# mask = repeat(mask, "b j -> (b h) () j", h=h)
# mask = repeat(mask, "b j -> (b h) () j", h=h)
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
# attention, what we cannot get enough of
# attention, what we cannot get enough of
...
@@ -198,7 +200,7 @@ class CrossAttention(nn.Module):
...
@@ -198,7 +200,7 @@ class CrossAttention(nn.Module):
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
...
...
src/diffusers/models/unet_rl.py
View file @
af6c1439
...
@@ -5,18 +5,19 @@ import math
...
@@ -5,18 +5,19 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
try
:
import
einops
from
einops.layers.torch
import
Rearrange
except
:
print
(
"Einops is not installed"
)
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
# try:
# import einops
# from einops.layers.torch import Rearrange
# except:
# print("Einops is not installed")
# pass
class
SinusoidalPosEmb
(
nn
.
Module
):
class
SinusoidalPosEmb
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
super
().
__init__
()
...
@@ -50,6 +51,21 @@ class Upsample1d(nn.Module):
...
@@ -50,6 +51,21 @@ class Upsample1d(nn.Module):
return
self
.
conv
(
x
)
return
self
.
conv
(
x
)
class RearrangeDim(nn.Module):
    """Einops-free stand-in for the ``Rearrange`` layers previously used here.

    The operation is selected from the rank of the input tensor:

    * 2-D ``(a, b)``       -> ``(a, b, 1)``: a trailing singleton axis is appended.
    * 3-D ``(a, b, c)``    -> ``(a, b, 1, c)``: a singleton axis is inserted at
      position 2.
    * 4-D ``(a, b, x, c)`` -> ``(a, b, c)``: index 0 along axis 2 is kept.

    Raises:
        ValueError: if the input rank is not 2, 3 or 4.
    """

    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        # Dispatch on the rank (number of axes), not on the size of any axis.
        rank = len(tensor.shape)
        if rank == 2:
            return tensor[:, :, None]
        if rank == 3:
            return tensor[:, :, None, :]
        if rank == 4:
            # NOTE(review): axis 2 is presumably the singleton axis added by the
            # 3-D branch; only index 0 is kept, so any other entries are dropped.
            return tensor[:, :, 0, :]
        # Report the rank that was actually tested. The previous message used
        # `len(tensor)`, which is the size of axis 0 and fails with TypeError on
        # 0-D tensors.
        raise ValueError(f"`len(tensor.shape)`: {rank} has to be 2, 3 or 4.")
class
Conv1dBlock
(
nn
.
Module
):
class
Conv1dBlock
(
nn
.
Module
):
"""
"""
Conv1d --> GroupNorm --> Mish
Conv1d --> GroupNorm --> Mish
...
@@ -60,9 +76,11 @@ class Conv1dBlock(nn.Module):
...
@@ -60,9 +76,11 @@ class Conv1dBlock(nn.Module):
self
.
block
=
nn
.
Sequential
(
self
.
block
=
nn
.
Sequential
(
nn
.
Conv1d
(
inp_channels
,
out_channels
,
kernel_size
,
padding
=
kernel_size
//
2
),
nn
.
Conv1d
(
inp_channels
,
out_channels
,
kernel_size
,
padding
=
kernel_size
//
2
),
Rearrange
(
"batch channels horizon -> batch channels 1 horizon"
),
RearrangeDim
(),
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn
.
GroupNorm
(
n_groups
,
out_channels
),
nn
.
GroupNorm
(
n_groups
,
out_channels
),
Rearrange
(
"batch channels 1 horizon -> batch channels horizon"
),
RearrangeDim
(),
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn
.
Mish
(),
nn
.
Mish
(),
)
)
...
@@ -84,7 +102,8 @@ class ResidualTemporalBlock(nn.Module):
...
@@ -84,7 +102,8 @@ class ResidualTemporalBlock(nn.Module):
self
.
time_mlp
=
nn
.
Sequential
(
self
.
time_mlp
=
nn
.
Sequential
(
nn
.
Mish
(),
nn
.
Mish
(),
nn
.
Linear
(
embed_dim
,
out_channels
),
nn
.
Linear
(
embed_dim
,
out_channels
),
Rearrange
(
"batch t -> batch t 1"
),
RearrangeDim
(),
# Rearrange("batch t -> batch t 1"),
)
)
self
.
residual_conv
=
(
self
.
residual_conv
=
(
...
@@ -184,7 +203,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -184,7 +203,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x : [ batch x horizon x transition ]
x : [ batch x horizon x transition ]
"""
"""
x
=
einops
.
rearrange
(
x
,
"b h t -> b t h"
)
# x = einops.rearrange(x, "b h t -> b t h")
x
=
x
.
permute
(
0
,
2
,
1
)
t
=
self
.
time_mlp
(
time
)
t
=
self
.
time_mlp
(
time
)
h
=
[]
h
=
[]
...
@@ -206,7 +226,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
...
@@ -206,7 +226,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x
=
self
.
final_conv
(
x
)
x
=
self
.
final_conv
(
x
)
x
=
einops
.
rearrange
(
x
,
"b t h -> b h t"
)
# x = einops.rearrange(x, "b t h -> b h t")
x
=
x
.
permute
(
0
,
2
,
1
)
return
x
return
x
...
@@ -263,7 +284,8 @@ class TemporalValue(nn.Module):
...
@@ -263,7 +284,8 @@ class TemporalValue(nn.Module):
x : [ batch x horizon x transition ]
x : [ batch x horizon x transition ]
"""
"""
x
=
einops
.
rearrange
(
x
,
"b h t -> b t h"
)
# x = einops.rearrange(x, "b h t -> b t h")
x
=
x
.
permute
(
0
,
2
,
1
)
t
=
self
.
time_mlp
(
time
)
t
=
self
.
time_mlp
(
time
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment