Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
d726857f
Commit
d726857f
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
remove einops from unet_ldm
parent
c991ffd4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
14 deletions
+37
-14
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+37
-14
No files found.
src/diffusers/models/unet_ldm.py
View file @
d726857f
...
...
@@ -6,19 +6,18 @@ import numpy as np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
try
:
from
einops
import
rearrange
,
repeat
except
:
print
(
"Einops is not installed"
)
pass
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
#try:
# from einops import rearrange, repeat
#except:
# print("Einops is not installed")
# pass
def
exists
(
val
):
return
val
is
not
None
...
...
@@ -153,7 +152,23 @@ class CrossAttention(nn.Module):
self
.
to_out
=
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
query_dim
),
nn
.
Dropout
(
dropout
))
def
reshape_heads_to_batch_dim
(
self
,
tensor
):
batch_size
,
seq_len
,
dim
=
tensor
.
shape
head_size
=
self
.
heads
tensor
=
tensor
.
reshape
(
batch_size
,
seq_len
,
head_size
,
dim
//
head_size
)
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
*
head_size
,
seq_len
,
dim
//
head_size
)
return
tensor
def
reshape_batch_dim_to_heads
(
self
,
tensor
):
batch_size
,
seq_len
,
dim
=
tensor
.
shape
head_size
=
self
.
heads
tensor
=
tensor
.
reshape
(
batch_size
//
head_size
,
head_size
,
seq_len
,
dim
)
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
//
head_size
,
seq_len
,
dim
*
head_size
)
return
tensor
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
):
batch_size
,
sequence_length
,
dim
=
x
.
shape
h
=
self
.
heads
q
=
self
.
to_q
(
x
)
...
...
@@ -161,21 +176,29 @@ class CrossAttention(nn.Module):
k
=
self
.
to_k
(
context
)
v
=
self
.
to_v
(
context
)
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
"b n (h d) -> (b h) n d"
,
h
=
h
),
(
q
,
k
,
v
))
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
v
=
self
.
reshape_heads_to_batch_dim
(
v
)
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
if
exists
(
mask
):
mask
=
rearrange
(
mask
,
"b ... -> b (...)"
)
# mask = rearrange(mask, "b ... -> b (...)")
maks
=
mask
.
reshape
(
batch_size
,
-
1
)
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
mask
=
repeat
(
mask
,
"b j -> (b h) () j"
,
h
=
h
)
# mask = repeat(mask, "b j -> (b h) () j", h=h)
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
# attention, what we cannot get enough of
attn
=
sim
.
softmax
(
dim
=-
1
)
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
rearrange
(
out
,
"(b h) n d -> b n (h d)"
,
h
=
h
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return
self
.
to_out
(
out
)
...
...
@@ -233,10 +256,10 @@ class SpatialTransformer(nn.Module):
x_in
=
x
x
=
self
.
norm
(
x
)
x
=
self
.
proj_in
(
x
)
x
=
rearrang
e
(
x
,
"b c h w -> b (h w)
c
"
)
x
=
x
.
permut
e
(
0
,
2
,
3
,
1
).
reshape
(
b
,
h
*
w
,
c
)
for
block
in
self
.
transformer_blocks
:
x
=
block
(
x
,
context
=
context
)
x
=
rearrange
(
x
,
"b (h
w
)
c
-> b c h w"
,
h
=
h
,
w
=
w
)
x
=
x
.
reshape
(
b
,
h
,
w
,
c
).
permute
(
0
,
3
,
1
,
2
)
x
=
self
.
proj_out
(
x
)
return
x
+
x_in
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment