Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
9dccc7dc
Commit
9dccc7dc
authored
Jun 28, 2022
by
Patrick von Platen
Browse files
refactor unet's attention
parent
52b3ff5e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
73 deletions
+112
-73
src/diffusers/models/attention2d.py
src/diffusers/models/attention2d.py
+65
-36
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+47
-37
No files found.
src/diffusers/models/attention2d.py
View file @
9dccc7dc
...
@@ -5,10 +5,6 @@ import torch.nn.functional as F
...
@@ -5,10 +5,6 @@ import torch.nn.functional as F
from
torch
import
nn
from
torch
import
nn
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
# unet_grad_tts.py
# unet_grad_tts.py
class
LinearAttention
(
torch
.
nn
.
Module
):
class
LinearAttention
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
def
__init__
(
self
,
dim
,
heads
=
4
,
dim_head
=
32
):
...
@@ -42,31 +38,48 @@ class AttnBlock(nn.Module):
...
@@ -42,31 +38,48 @@ class AttnBlock(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
norm
=
N
ormaliz
e
(
in_channels
)
self
.
norm
=
n
ormaliz
ation
(
in_channels
,
swish
=
None
,
eps
=
1e-6
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
print
(
"x"
,
x
.
abs
().
sum
())
h_
=
x
h_
=
x
h_
=
self
.
norm
(
h_
)
h_
=
self
.
norm
(
h_
)
print
(
"hid_states shape"
,
h_
.
shape
)
print
(
"hid_states"
,
h_
.
abs
().
sum
())
print
(
"hid_states - 3 - 3"
,
h_
.
view
(
h_
.
shape
[
0
],
h_
.
shape
[
1
],
-
1
)[:,
:
3
,
-
3
:])
q
=
self
.
q
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
v
=
self
.
v
(
h_
)
print
(
self
.
q
)
print
(
"q_shape"
,
q
.
shape
)
print
(
"q"
,
q
.
abs
().
sum
())
# print("k_shape", k.shape)
# print("k", k.abs().sum())
# print("v_shape", v.shape)
# print("v", v.abs().sum())
# compute attention
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
print
(
"weight"
,
w_
.
abs
().
sum
())
# attend to values
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
...
@@ -92,6 +105,7 @@ class AttentionBlock(nn.Module):
...
@@ -92,6 +105,7 @@ class AttentionBlock(nn.Module):
use_checkpoint
=
False
,
use_checkpoint
=
False
,
encoder_channels
=
None
,
encoder_channels
=
None
,
use_new_attention_order
=
False
,
# TODO(Patrick) -> is never used, maybe delete?
use_new_attention_order
=
False
,
# TODO(Patrick) -> is never used, maybe delete?
overwrite_qkv
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
...
@@ -102,57 +116,72 @@ class AttentionBlock(nn.Module):
...
@@ -102,57 +116,72 @@ class AttentionBlock(nn.Module):
channels
%
num_head_channels
==
0
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
normalization
(
channels
,
swish
=
0.0
)
self
.
norm
=
normalization
(
channels
,
swish
=
0.0
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
self
.
attention
=
QKVAttention
(
self
.
num_heads
)
self
.
n_heads
=
self
.
num_heads
if
encoder_channels
is
not
None
:
if
encoder_channels
is
not
None
:
self
.
encoder_kv
=
conv_nd
(
1
,
encoder_channels
,
channels
*
2
,
1
)
self
.
encoder_kv
=
conv_nd
(
1
,
encoder_channels
,
channels
*
2
,
1
)
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
def
forward
(
self
,
x
,
encoder_out
=
None
):
self
.
overwrite_qkv
=
overwrite_qkv
b
,
c
,
*
spatial
=
x
.
shape
if
overwrite_qkv
:
qkv
=
self
.
qkv
(
self
.
norm
(
x
).
view
(
b
,
c
,
-
1
))
in_channels
=
channels
if
encoder_out
is
not
None
:
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
encoder_out
=
self
.
encoder_kv
(
encoder_out
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
h
=
self
.
attention
(
qkv
,
encoder_out
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
else
:
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
h
=
self
.
attention
(
qkv
)
h
=
self
.
proj_out
(
h
)
return
x
+
h
.
reshape
(
b
,
c
,
*
spatial
)
self
.
is_overwritten
=
False
class
QKVAttention
(
nn
.
M
odule
):
def
set_weights
(
self
,
m
odule
):
"""
if
self
.
overwrite_qkv
:
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
qkv_weight
=
torch
.
cat
([
module
.
q
.
weight
.
data
,
module
.
k
.
weight
.
data
,
module
.
v
.
weight
.
data
],
dim
=
0
)[:,
:,
:,
0
]
"""
qkv_bias
=
torch
.
cat
([
module
.
q
.
bias
.
data
,
module
.
k
.
bias
.
data
,
module
.
v
.
bias
.
data
],
dim
=
0
)
def
__init__
(
self
,
n_heads
):
self
.
qkv
.
weight
.
data
=
qkv_weight
super
().
__init__
()
self
.
qkv
.
bias
.
data
=
qkv_bias
self
.
n_heads
=
n_heads
proj_out
=
zero_module
(
conv_nd
(
1
,
self
.
channels
,
self
.
channels
,
1
))
proj_out
.
weight
.
data
=
module
.
proj_out
.
weight
.
data
[:,
:,
:,
0
]
proj_out
.
bias
.
data
=
module
.
proj_out
.
bias
.
data
def
forward
(
self
,
qkv
,
encoder_kv
=
None
):
self
.
proj_out
=
proj_out
"""
Apply QKV attention.
def
forward
(
self
,
x
,
encoder_out
=
None
):
if
self
.
overwrite_qkv
and
not
self
.
is_overwritten
:
self
.
set_weights
(
self
)
self
.
is_overwritten
=
True
b
,
c
,
*
spatial
=
x
.
shape
hid_states
=
self
.
norm
(
x
).
view
(
b
,
c
,
-
1
)
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
qkv
=
self
.
qkv
(
hid_states
)
attention.
"""
bs
,
width
,
length
=
qkv
.
shape
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
if
encoder_kv
is
not
None
:
if
encoder_out
is
not
None
:
encoder_kv
=
self
.
encoder_kv
(
encoder_out
)
assert
encoder_kv
.
shape
[
1
]
==
self
.
n_heads
*
ch
*
2
assert
encoder_kv
.
shape
[
1
]
==
self
.
n_heads
*
ch
*
2
ek
,
ev
=
encoder_kv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
2
,
-
1
).
split
(
ch
,
dim
=
1
)
ek
,
ev
=
encoder_kv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
2
,
-
1
).
split
(
ch
,
dim
=
1
)
k
=
torch
.
cat
([
ek
,
k
],
dim
=-
1
)
k
=
torch
.
cat
([
ek
,
k
],
dim
=-
1
)
v
=
torch
.
cat
([
ev
,
v
],
dim
=-
1
)
v
=
torch
.
cat
([
ev
,
v
],
dim
=-
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
h
=
a
.
reshape
(
bs
,
-
1
,
length
)
h
=
self
.
proj_out
(
h
)
return
x
+
h
.
reshape
(
b
,
c
,
*
spatial
)
def
conv_nd
(
dims
,
*
args
,
**
kwargs
):
def
conv_nd
(
dims
,
*
args
,
**
kwargs
):
...
@@ -169,8 +198,8 @@ def conv_nd(dims, *args, **kwargs):
...
@@ -169,8 +198,8 @@ def conv_nd(dims, *args, **kwargs):
class
GroupNorm32
(
nn
.
GroupNorm
):
class
GroupNorm32
(
nn
.
GroupNorm
):
def
__init__
(
self
,
num_groups
,
num_channels
,
swish
,
eps
=
1e-5
):
def
__init__
(
self
,
num_groups
,
num_channels
,
swish
,
eps
=
1e-5
,
affine
=
True
):
super
().
__init__
(
num_groups
=
num_groups
,
num_channels
=
num_channels
,
eps
=
eps
)
super
().
__init__
(
num_groups
=
num_groups
,
num_channels
=
num_channels
,
eps
=
eps
,
affine
=
affine
)
self
.
swish
=
swish
self
.
swish
=
swish
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -182,13 +211,13 @@ class GroupNorm32(nn.GroupNorm):
...
@@ -182,13 +211,13 @@ class GroupNorm32(nn.GroupNorm):
return
y
return
y
def
normalization
(
channels
,
swish
=
0.0
):
def
normalization
(
channels
,
swish
=
0.0
,
eps
=
1e-5
):
"""
"""
Make a standard normalization layer, with an optional swish activation.
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
"""
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
,
eps
=
eps
,
affine
=
True
)
def
zero_module
(
module
):
def
zero_module
(
module
):
...
...
src/diffusers/models/unet.py
View file @
9dccc7dc
...
@@ -32,6 +32,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -32,6 +32,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
Upsample
from
.resnet
import
Downsample
,
Upsample
from
.attention2d
import
AttnBlock
,
AttentionBlock
def
nonlinearity
(
x
):
def
nonlinearity
(
x
):
...
@@ -85,42 +86,42 @@ class ResnetBlock(nn.Module):
...
@@ -85,42 +86,42 @@ class ResnetBlock(nn.Module):
return
x
+
h
return
x
+
h
class
AttnBlock
(
nn
.
Module
):
#
class AttnBlock(nn.Module):
def
__init__
(
self
,
in_channels
):
#
def __init__(self, in_channels):
super
().
__init__
()
#
super().__init__()
self
.
in_channels
=
in_channels
#
self.in_channels = in_channels
#
self
.
norm
=
Normalize
(
in_channels
)
#
self.norm = Normalize(in_channels)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
#
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
#
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
#
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
#
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
#
def
forward
(
self
,
x
):
#
def forward(self, x):
h_
=
x
#
h_ = x
h_
=
self
.
norm
(
h_
)
#
h_ = self.norm(h_)
q
=
self
.
q
(
h_
)
#
q = self.q(h_)
k
=
self
.
k
(
h_
)
#
k = self.k(h_)
v
=
self
.
v
(
h_
)
#
v = self.v(h_)
#
# compute attention
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
#
b, c, h, w = q.shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
#
q = q.reshape(b, c, h * w)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
#
q = q.permute(0, 2, 1) # b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
#
k = k.reshape(b, c, h * w) # b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
#
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
#
w_ = w_ * (int(c) ** (-0.5))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
#
w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
#
v = v.reshape(b, c, h * w)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
#
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
#
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
#
h_ = h_.reshape(b, c, h, w)
#
h_
=
self
.
proj_out
(
h_
)
#
h_ = self.proj_out(h_)
#
return
x
+
h_
#
return x + h_
class
UNetModel
(
ModelMixin
,
ConfigMixin
):
class
UNetModel
(
ModelMixin
,
ConfigMixin
):
...
@@ -174,6 +175,7 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -174,6 +175,7 @@ class UNetModel(ModelMixin, ConfigMixin):
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
attn_2
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
for
i_block
in
range
(
self
.
num_res_blocks
):
...
@@ -184,10 +186,12 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -184,10 +186,12 @@ class UNetModel(ModelMixin, ConfigMixin):
)
)
block_in
=
block_out
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
# attn.append(AttnBlock(block_in))
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
down
=
nn
.
Module
()
down
=
nn
.
Module
()
down
.
block
=
block
down
.
block
=
block
down
.
attn
=
attn
down
.
attn
=
attn
down
.
attn_2
=
attn_2
if
i_level
!=
self
.
num_resolutions
-
1
:
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
use_conv
=
resamp_with_conv
,
padding
=
0
)
down
.
downsample
=
Downsample
(
block_in
,
use_conv
=
resamp_with_conv
,
padding
=
0
)
curr_res
=
curr_res
//
2
curr_res
=
curr_res
//
2
...
@@ -198,7 +202,8 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -198,7 +202,8 @@ class UNetModel(ModelMixin, ConfigMixin):
self
.
mid
.
block_1
=
ResnetBlock
(
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
# self.mid.attn_1 = AttnBlock(block_in)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
...
@@ -223,7 +228,8 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -223,7 +228,8 @@ class UNetModel(ModelMixin, ConfigMixin):
)
)
block_in
=
block_out
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
# attn.append(AttnBlock(block_in))
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
up
=
nn
.
Module
()
up
=
nn
.
Module
()
up
.
block
=
block
up
.
block
=
block
up
.
attn
=
attn
up
.
attn
=
attn
...
@@ -254,7 +260,11 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -254,7 +260,11 @@ class UNetModel(ModelMixin, ConfigMixin):
for
i_block
in
range
(
self
.
num_res_blocks
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
# self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
# h = self.down[i_level].attn_2[i_block](h)
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
# print("Result", (h - h_2).abs().sum())
hs
.
append
(
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment