Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
e372767c
Unverified
Commit
e372767c
authored
Jun 28, 2022
by
Patrick von Platen
Committed by
GitHub
Jun 28, 2022
Browse files
Merge pull request #37 from huggingface/merg_unet_attn_into_glide
merge unet attention into glide attention
parents
9dccc7dc
c45fd749
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
107 deletions
+9
-107
src/diffusers/models/attention2d.py
src/diffusers/models/attention2d.py
+0
-56
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+1
-42
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+8
-9
No files found.
src/diffusers/models/attention2d.py
View file @
e372767c
...
...
@@ -32,62 +32,6 @@ class LinearAttention(torch.nn.Module):
return
self
.
to_out
(
out
)
# unet.py
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
normalization
(
in_channels
,
swish
=
None
,
eps
=
1e-6
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
print
(
"x"
,
x
.
abs
().
sum
())
h_
=
x
h_
=
self
.
norm
(
h_
)
print
(
"hid_states shape"
,
h_
.
shape
)
print
(
"hid_states"
,
h_
.
abs
().
sum
())
print
(
"hid_states - 3 - 3"
,
h_
.
view
(
h_
.
shape
[
0
],
h_
.
shape
[
1
],
-
1
)[:,
:
3
,
-
3
:])
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
print
(
self
.
q
)
print
(
"q_shape"
,
q
.
shape
)
print
(
"q"
,
q
.
abs
().
sum
())
# print("k_shape", k.shape)
# print("k", k.abs().sum())
# print("v_shape", v.shape)
# print("v", v.abs().sum())
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
print
(
"weight"
,
w_
.
abs
().
sum
())
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
# unet_glide.py & unet_ldm.py
class
AttentionBlock
(
nn
.
Module
):
"""
...
...
src/diffusers/models/unet.py
View file @
e372767c
...
...
@@ -32,7 +32,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
Upsample
from
.attention2d
import
AttnBlock
,
AttentionBlock
from
.attention2d
import
AttentionBlock
def
nonlinearity
(
x
):
...
...
@@ -86,44 +86,6 @@ class ResnetBlock(nn.Module):
return
x
+
h
#class AttnBlock(nn.Module):
# def __init__(self, in_channels):
# super().__init__()
# self.in_channels = in_channels
#
# self.norm = Normalize(in_channels)
# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x):
# h_ = x
# h_ = self.norm(h_)
# q = self.q(h_)
# k = self.k(h_)
# v = self.v(h_)
#
# compute attention
# b, c, h, w = q.shape
# q = q.reshape(b, c, h * w)
# q = q.permute(0, 2, 1) # b,hw,c
# k = k.reshape(b, c, h * w) # b,c,hw
# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values
# v = v.reshape(b, c, h * w)
# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
# h_ = h_.reshape(b, c, h, w)
#
# h_ = self.proj_out(h_)
#
# return x + h_
class
UNetModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
...
...
@@ -186,7 +148,6 @@ class UNetModel(ModelMixin, ConfigMixin):
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
# attn.append(AttnBlock(block_in))
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
down
=
nn
.
Module
()
down
.
block
=
block
...
...
@@ -202,7 +163,6 @@ class UNetModel(ModelMixin, ConfigMixin):
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# self.mid.attn_1 = AttnBlock(block_in)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
...
...
@@ -228,7 +188,6 @@ class UNetModel(ModelMixin, ConfigMixin):
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
# attn.append(AttnBlock(block_in))
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
up
=
nn
.
Module
()
up
.
block
=
block
...
...
tests/test_modeling_utils.py
View file @
e372767c
...
...
@@ -858,25 +858,26 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
([
0.22
50
,
0.3375
,
0.23
60
,
0.09
30
,
0.34
40
,
0.3156
,
0.1937
,
0.3585
,
0.1761
])
expected_slice
=
torch
.
tensor
([
0.22
49
,
0.3375
,
0.23
59
,
0.09
29
,
0.34
39
,
0.3156
,
0.1937
,
0.3585
,
0.1761
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_ddim_cifar10
(
self
):
generator
=
torch
.
manual_seed
(
0
)
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
(
[
-
0.
738
3
,
-
0.
738
5
,
-
0.
7298
,
-
0.
7364
,
-
0.7
414
,
-
0.
7239
,
-
0.6
737
,
-
0.
6813
,
-
0.70
68
]
[
-
0.
655
3
,
-
0.
676
5
,
-
0.
6799
,
-
0.
6749
,
-
0.7
006
,
-
0.
6974
,
-
0.6
991
,
-
0.
7116
,
-
0.70
94
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
...
...
@@ -895,7 +896,7 @@ class PipelineTesterMixin(unittest.TestCase):
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
(
[
-
0.7
888
,
-
0.7
870
,
-
0.77
5
9
,
-
0.7
823
,
-
0.80
14
,
-
0.7
608
,
-
0.68
18
,
-
0.71
30
,
-
0.74
71
]
[
-
0.7
925
,
-
0.7
902
,
-
0.77
8
9
,
-
0.7
796
,
-
0.80
00
,
-
0.7
596
,
-
0.68
52
,
-
0.71
25
,
-
0.74
94
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
...
...
@@ -966,24 +967,22 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
def
test_score_sde_ve_pipeline
(
self
):
torch
.
manual_seed
(
0
)
model
=
NCSNpp
.
from_pretrained
(
"fusing/ffhq_ncsnpp"
)
scheduler
=
ScoreSdeVeScheduler
.
from_config
(
"fusing/ffhq_ncsnpp"
)
sde_ve
=
ScoreSdeVePipeline
(
model
=
model
,
scheduler
=
scheduler
)
torch
.
manual_seed
(
0
)
image
=
sde_ve
(
num_inference_steps
=
2
)
expected_image_sum
=
33828
10112
.0
expected_image_mean
=
1075.3
66455078125
expected_image_sum
=
33828
49024
.0
expected_image_mean
=
1075.3
788
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
@
slow
def
test_score_sde_vp_pipeline
(
self
):
model
=
NCSNpp
.
from_pretrained
(
"fusing/cifar10-ddpmpp-vp"
)
scheduler
=
ScoreSdeVpScheduler
.
from_config
(
"fusing/cifar10-ddpmpp-vp"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment