Project: renzhc/diffusers_dcu

Commit 9dccc7dc, authored Jun 28, 2022 by Patrick von Platen
Parent: 52b3ff5e

refactor unet's attention
Changes: 2 files changed, 112 additions and 73 deletions (+112 −73)

  src/diffusers/models/attention2d.py   +65 −36
  src/diffusers/models/unet.py          +47 −37
src/diffusers/models/attention2d.py
@@ -5,10 +5,6 @@ import torch.nn.functional as F
 from torch import nn


-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


 # unet_grad_tts.py
 class LinearAttention(torch.nn.Module):
     def __init__(self, dim, heads=4, dim_head=32):
@@ -42,31 +38,48 @@ class AttnBlock(nn.Module):
         super().__init__()
         self.in_channels = in_channels

-        self.norm = Normalize(in_channels)
+        self.norm = normalization(in_channels, swish=None, eps=1e-6)
         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x):
+        print("x", x.abs().sum())
         h_ = x
         h_ = self.norm(h_)
+        print("hid_states shape", h_.shape)
+        print("hid_states", h_.abs().sum())
+        print("hid_states - 3 - 3", h_.view(h_.shape[0], h_.shape[1], -1)[:, :3, -3:])
         q = self.q(h_)
         k = self.k(h_)
         v = self.v(h_)
+        print(self.q)
+        print("q_shape", q.shape)
+        print("q", q.abs().sum())
+        # print("k_shape", k.shape)
+        # print("k", k.abs().sum())
+        # print("v_shape", v.shape)
+        # print("v", v.abs().sum())

         # compute attention
         b, c, h, w = q.shape
         q = q.reshape(b, c, h * w)
         q = q.permute(0, 2, 1)  # b,hw,c
         k = k.reshape(b, c, h * w)  # b,c,hw
         w_ = torch.bmm(q, k)  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
         w_ = w_ * (int(c) ** (-0.5))
         w_ = torch.nn.functional.softmax(w_, dim=2)
+        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+        print("weight", w_.abs().sum())

         # attend to values
         v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
         h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
         h_ = h_.reshape(b, c, h, w)
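
The hunk above keeps AttnBlock's spatial self-attention and mostly adds debug prints (and moves the weight permute). For reference, a minimal, self-contained sketch of the same attention math, with purely illustrative shapes (not part of the commit):

# Standalone sketch of the AttnBlock attention computation; shapes are illustrative only.
import torch

b, c, h, w = 2, 64, 8, 8                              # hypothetical batch/channel/spatial sizes
q = torch.randn(b, c, h, w)                           # outputs of the 1x1 q/k/v convolutions
k = torch.randn(b, c, h, w)
v = torch.randn(b, c, h, w)

q_flat = q.reshape(b, c, h * w).permute(0, 2, 1)      # (b, hw, c)
k_flat = k.reshape(b, c, h * w)                       # (b, c, hw)
attn = torch.bmm(q_flat, k_flat) * (c ** -0.5)        # (b, hw, hw), scaled dot products
attn = attn.softmax(dim=2)                            # normalize over the key positions

v_flat = v.reshape(b, c, h * w)                       # (b, c, hw)
out = torch.bmm(v_flat, attn.permute(0, 2, 1))        # (b, c, hw); transpose so dim 2 indexes queries
out = out.reshape(b, c, h, w)
print(out.shape)                                      # torch.Size([2, 64, 8, 8])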
@@ -92,6 +105,7 @@ class AttentionBlock(nn.Module):
         use_checkpoint=False,
         encoder_channels=None,
         use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
+        overwrite_qkv=False,
     ):
         super().__init__()
         self.channels = channels
@@ -102,57 +116,72 @@ class AttentionBlock(nn.Module):
                 channels % num_head_channels == 0
             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels
         self.use_checkpoint = use_checkpoint
         self.norm = normalization(channels, swish=0.0)
         self.qkv = conv_nd(1, channels, channels * 3, 1)
-        self.attention = QKVAttention(self.num_heads)
+        self.n_heads = self.num_heads

         if encoder_channels is not None:
             self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

-    def forward(self, x, encoder_out=None):
-        b, c, *spatial = x.shape
-        qkv = self.qkv(self.norm(x).view(b, c, -1))
-        if encoder_out is not None:
-            encoder_out = self.encoder_kv(encoder_out)
-            h = self.attention(qkv, encoder_out)
-        else:
-            h = self.attention(qkv)
-        h = self.proj_out(h)
-        return x + h.reshape(b, c, *spatial)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv, encoder_kv=None):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
-        attention.
-        """
+        self.overwrite_qkv = overwrite_qkv
+        if overwrite_qkv:
+            in_channels = channels
+            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+        self.is_overwritten = False
+
+    def set_weights(self, module):
+        if self.overwrite_qkv:
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
+
+            self.qkv.weight.data = qkv_weight
+            self.qkv.bias.data = qkv_bias
+
+            proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
+            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
+            proj_out.bias.data = module.proj_out.bias.data
+
+            self.proj_out = proj_out
+
+    def forward(self, x, encoder_out=None):
+        if self.overwrite_qkv and not self.is_overwritten:
+            self.set_weights(self)
+            self.is_overwritten = True
+
+        b, c, *spatial = x.shape
+        hid_states = self.norm(x).view(b, c, -1)
+
+        qkv = self.qkv(hid_states)
         bs, width, length = qkv.shape
         assert width % (3 * self.n_heads) == 0
         ch = width // (3 * self.n_heads)
         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        if encoder_kv is not None:
+
+        if encoder_out is not None:
+            encoder_kv = self.encoder_kv(encoder_out)
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
             k = torch.cat([ek, k], dim=-1)
             v = torch.cat([ev, v], dim=-1)
+
         scale = 1 / math.sqrt(math.sqrt(ch))
         weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
         a = torch.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
+        h = a.reshape(bs, -1, length)
+
+        h = self.proj_out(h)
+
+        return x + h.reshape(b, c, *spatial)


 def conv_nd(dims, *args, **kwargs):
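
The new overwrite_qkv path is the core of the refactor: set_weights() ports the legacy per-tensor 1x1 convolutions (q, k, v, proj_out) into AttentionBlock's fused qkv convolution. Here is a small sketch of why concatenating the weights along the output-channel dimension preserves the projections. The names are hypothetical; the diff additionally drops the trailing kernel dimension of the Conv2d weights via [:, :, :, 0], whereas this sketch uses Conv1d throughout to keep it short:

# Sketch: fusing separate 1x1 q/k/v convolutions into one conv, as set_weights() does above.
import torch
from torch import nn

channels, b, hw = 64, 2, 16
x = torch.randn(b, channels, hw)                      # flattened spatial input (b, c, h*w)

q = nn.Conv1d(channels, channels, 1)
k = nn.Conv1d(channels, channels, 1)
v = nn.Conv1d(channels, channels, 1)

fused = nn.Conv1d(channels, channels * 3, 1)
fused.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
fused.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)

q_ref, k_ref, v_ref = q(x), k(x), v(x)
q_new, k_new, v_new = fused(x).chunk(3, dim=1)        # split the fused output back into q, k, v

print(torch.allclose(q_ref, q_new, atol=1e-6),
      torch.allclose(k_ref, k_new, atol=1e-6),
      torch.allclose(v_ref, v_new, atol=1e-6))        # True True True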
@@ -169,8 +198,8 @@ def conv_nd(dims, *args, **kwargs):
 class GroupNorm32(nn.GroupNorm):
-    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
-        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
+    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
+        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
         self.swish = swish

     def forward(self, x):
@@ -182,13 +211,13 @@ class GroupNorm32(nn.GroupNorm):
         return y


-def normalization(channels, swish=0.0):
+def normalization(channels, swish=0.0, eps=1e-5):
     """
     Make a standard normalization layer, with an optional swish activation.
     :param channels: number of input channels. :return: an nn.Module for normalization.
     """
-    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
+    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)


 def zero_module(module):
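
With eps and affine now threaded through normalization() and GroupNorm32, the ported AttnBlock can keep the eps=1e-6 affine group norm that the deleted Normalize() helper used to build. A quick, illustrative sanity sketch of that equivalence, assuming GroupNorm32 reduces to a plain nn.GroupNorm when swish is disabled (stand-in modules, not the commit's code):

# Sketch: normalization(channels, swish=None, eps=1e-6) is meant to build the same layer
# as the deleted Normalize() helper, i.e. GroupNorm(32, channels, eps=1e-6, affine=True).
import torch
from torch import nn

channels = 64
x = torch.randn(2, channels, 8, 8)

old_norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)  # old Normalize()
new_norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)  # assumed new equivalent
new_norm.load_state_dict(old_norm.state_dict())                                       # same affine parameters

print(torch.allclose(old_norm(x), new_norm(x)))                                       # True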
src/diffusers/models/unet.py
@@ -32,6 +32,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
+from .attention2d import AttnBlock, AttentionBlock


 def nonlinearity(x):
@@ -85,42 +86,42 @@ class ResnetBlock(nn.Module):
         return x + h


-class AttnBlock(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)  # b,hw,c
-        k = k.reshape(b, c, h * w)  # b,c,hw
-        w_ = torch.bmm(q, k)  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
-        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
-        h_ = h_.reshape(b, c, h, w)
-
-        h_ = self.proj_out(h_)
-
-        return x + h_
+# class AttnBlock(nn.Module):
+#     def __init__(self, in_channels):
+#         super().__init__()
+#         self.in_channels = in_channels
+#
+#         self.norm = Normalize(in_channels)
+#         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#
+#     def forward(self, x):
+#         h_ = x
+#         h_ = self.norm(h_)
+#         q = self.q(h_)
+#         k = self.k(h_)
+#         v = self.v(h_)
+#
+#         # compute attention
+#         b, c, h, w = q.shape
+#         q = q.reshape(b, c, h * w)
+#         q = q.permute(0, 2, 1)  # b,hw,c
+#         k = k.reshape(b, c, h * w)  # b,c,hw
+#         w_ = torch.bmm(q, k)  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+#         w_ = w_ * (int(c) ** (-0.5))
+#         w_ = torch.nn.functional.softmax(w_, dim=2)
+#
+#         # attend to values
+#         v = v.reshape(b, c, h * w)
+#         w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+#         h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+#         h_ = h_.reshape(b, c, h, w)
+#
+#         h_ = self.proj_out(h_)
+#
+#         return x + h_


 class UNetModel(ModelMixin, ConfigMixin):
@@ -174,6 +175,7 @@ class UNetModel(ModelMixin, ConfigMixin):
         for i_level in range(self.num_resolutions):
             block = nn.ModuleList()
             attn = nn.ModuleList()
+            attn_2 = nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
             for i_block in range(self.num_res_blocks):
@@ -184,10 +186,12 @@ class UNetModel(ModelMixin, ConfigMixin):
                     )
                 )
                 block_in = block_out
                 if curr_res in attn_resolutions:
-                    attn.append(AttnBlock(block_in))
+                    # attn.append(AttnBlock(block_in))
+                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
             down = nn.Module()
             down.block = block
             down.attn = attn
+            down.attn_2 = attn_2
             if i_level != self.num_resolutions - 1:
                 down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
@@ -198,7 +202,8 @@ class UNetModel(ModelMixin, ConfigMixin):
         self.mid.block_1 = ResnetBlock(
             in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
         )
-        self.mid.attn_1 = AttnBlock(block_in)
+        # self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
         self.mid.block_2 = ResnetBlock(
             in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
         )
@@ -223,7 +228,8 @@ class UNetModel(ModelMixin, ConfigMixin):
                     )
                 )
                 block_in = block_out
                 if curr_res in attn_resolutions:
-                    attn.append(AttnBlock(block_in))
+                    # attn.append(AttnBlock(block_in))
+                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
             up = nn.Module()
             up.block = block
             up.attn = attn
@@ -254,7 +260,11 @@ class UNetModel(ModelMixin, ConfigMixin):
             for i_block in range(self.num_res_blocks):
                 h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
+                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
+                    # h = self.down[i_level].attn_2[i_block](h)
                     h = self.down[i_level].attn[i_block](h)
+                    # print("Result", (h - h_2).abs().sum())
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))
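
The commented-out attn_2 and print("Result", ...) lines above hint at the check this refactor is building towards: with overwrite_qkv=True, an AttentionBlock fed the old AttnBlock's weights should reproduce its output. A hedged sketch of such a check, using only APIs introduced in this commit; the block size, seed, and input shape are illustrative, and the small eps difference between the two norm layers may leave a tiny residual:

# Sketch of the equivalence check hinted at by the commented-out lines above:
# port the old AttnBlock's weights into an AttentionBlock and compare outputs.
# Illustrative usage, not the commit's own test.
import torch
from diffusers.models.attention2d import AttnBlock, AttentionBlock

torch.manual_seed(0)
block_in = 64
x = torch.randn(1, block_in, 16, 16)

old_attn = AttnBlock(block_in)
new_attn = AttentionBlock(block_in, overwrite_qkv=True)
new_attn.set_weights(old_attn)      # copy q/k/v/proj_out into the fused qkv conv
new_attn.is_overwritten = True      # skip the lazy overwrite from its own (random) convs

with torch.no_grad():
    print("Result", (old_attn(x) - new_attn(x)).abs().sum())   # expected to be ~0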