renzhc / diffusers_dcu / Commits / 31d1f3c8

Commit 31d1f3c8, authored Jun 28, 2022 by Patrick von Platen
Parent: 635da723

    final fix

Showing 5 changed files with 30 additions and 19 deletions.
src/diffusers/models/attention2d.py                  +20  -10
src/diffusers/models/unet.py                          +4   -4
src/diffusers/models/unet_grad_tts.py                 +1   -1
src/diffusers/models/unet_sde_score_estimation.py     +2   -3
tests/test_modeling_utils.py                          +3   -1
src/diffusers/models/attention2d.py
@@ -91,11 +91,15 @@ class AttentionBlock(nn.Module):
         self.NIN_2 = NIN(channels, channels)
         self.NIN_3 = NIN(channels, channels)
         self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
 
         self.is_overwritten = False
 
     def set_weights(self, module):
         if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
+                :, :, :, 0
+            ]
             qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
 
             self.qkv.weight.data = qkv_weight
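The reformatted torch.cat call above folds the separate q/k/v projections of a checkpointed module into the block's single fused qkv convolution. Below is a minimal, self-contained sketch of that weight-folding idea; the 1x1 Conv2d layers and the fused Conv1d are stand-ins for illustration, not the actual diffusers classes from this file.

# Sketch of the qkv weight folding done in set_weights() above (overwrite_qkv path).
import torch
import torch.nn as nn

channels = 8

# "module" side: separate 1x1 Conv2d projections, as in a DDPM-style attention block.
q = nn.Conv2d(channels, channels, kernel_size=1)
k = nn.Conv2d(channels, channels, kernel_size=1)
v = nn.Conv2d(channels, channels, kernel_size=1)

# "self" side: a single fused Conv1d that produces q, k and v in one pass.
qkv = nn.Conv1d(channels, channels * 3, kernel_size=1)

# Concatenate along the output-channel dim; Conv2d 1x1 weights have shape
# (out, in, 1, 1), so dropping the trailing spatial dim with [:, :, :, 0]
# yields the (out, in, 1) layout a Conv1d expects.
qkv_weight = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)[:, :, :, 0]
qkv_bias = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)

qkv.weight.data = qkv_weight
qkv.bias.data = qkv_bias

# Sanity check: the fused projection matches the three separate ones.
x = torch.randn(1, channels, 4, 4)
fused = qkv(x.reshape(1, channels, -1))
separate = torch.cat([q(x), k(x), v(x)], dim=1).reshape(1, channels * 3, -1)
assert torch.allclose(fused, separate, atol=1e-5)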
@@ -107,14 +111,19 @@ class AttentionBlock(nn.Module):
             self.proj_out = proj_out
 
         elif self.overwrite_linear:
-            self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+            self.qkv.weight.data = torch.concat(
+                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
+            )[:, :, None]
             self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
 
             self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
             self.proj_out.bias.data = self.NIN_3.b.data
 
             self.norm.weight.data = self.GroupNorm_0.weight.data
             self.norm.bias.data = self.GroupNorm_0.bias.data
 
     def forward(self, x, encoder_out=None):
-        if self.overwrite_qkv and not self.is_overwritten:
+        if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
             self.set_weights(self)
             self.is_overwritten = True
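Besides the line-wrapping of the torch.concat call, the substantive change in this hunk is the forward() guard: weights are now copied when either overwrite_qkv or overwrite_linear is set, so the NIN-based (score_sde) variant also gets its parameters overwritten on the first call. Below is a small sketch of the overwrite_linear copy, assuming NIN stores its weight W as (in, out) and its bias b as (out,); NINStandIn is a stand-in, not the real class.

# Sketch of the overwrite_linear path: folding three NIN-style 1x1 linear layers
# into the fused Conv1d by transposing W and adding a trailing kernel dimension.
import torch
import torch.nn as nn

class NINStandIn(nn.Module):
    def __init__(self, in_dim, num_units):
        super().__init__()
        self.W = nn.Parameter(torch.randn(in_dim, num_units) * 0.02)  # assumed (in, out) layout
        self.b = nn.Parameter(torch.zeros(num_units))

channels = 8
nin_q, nin_k, nin_v = (NINStandIn(channels, channels) for _ in range(3))
qkv = nn.Conv1d(channels, channels * 3, kernel_size=1)

# Transpose (in, out) -> (out, in), then [:, :, None] adds the kernel dim
# to match Conv1d's (out, in, 1) weight layout.
qkv.weight.data = torch.concat([nin_q.W.data.T, nin_k.W.data.T, nin_v.W.data.T], dim=0)[:, :, None]
qkv.bias.data = torch.concat([nin_q.b.data, nin_k.b.data, nin_v.b.data], dim=0)

# The fused conv reproduces the per-NIN einsum over the channel dimension.
x = torch.randn(1, channels, 16)
ref_q = torch.einsum("bcl,cd->bdl", x, nin_q.W) + nin_q.b[None, :, None]
assert torch.allclose(qkv(x)[:, :channels], ref_q, atol=1e-5)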
@@ -152,7 +161,7 @@ class AttentionBlock(nn.Module):
 # unet_score_estimation.py
-#class AttnBlockpp(nn.Module):
+# class AttnBlockpp(nn.Module):
 #     """Channel-wise self-attention block. Modified from DDPM."""
 #
 #     def __init__(
@@ -187,14 +196,11 @@ class AttentionBlock(nn.Module):
 # self.num_heads = channels // num_head_channels
 #
 # self.use_checkpoint = use_checkpoint
-# self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None)
-# self.qkv = conv_nd(1, channels, channels * 3, 1)
+# self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+# self.qkv = nn.Conv1d(channels, channels * 3, 1)
 # self.n_heads = self.num_heads
 #
 # if encoder_channels is not None:
 #     self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
 #
-# self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+# self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 #
 # self.is_weight_set = False
 #
@@ -205,6 +211,9 @@ class AttentionBlock(nn.Module):
 # self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
 # self.proj_out.bias.data = self.NIN_3.b.data
 #
+# self.norm.weight.data = self.GroupNorm_0.weight.data
+# self.norm.bias.data = self.GroupNorm_0.bias.data
+#
 # def forward(self, x):
 #     if not self.is_weight_set:
 #         self.set_weights()
@@ -261,6 +270,7 @@ class AttentionBlock(nn.Module):
 #
 #     return (x + h) / np.sqrt(2.0)
 
 
+# TODO(Patrick) - this can and should be removed
 def zero_module(module):
     """
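The new TODO marks zero_module for removal. Its body is not part of this diff; a typical implementation in the guided-diffusion lineage this file borrows from (an assumption here, not shown in the commit) zeroes every parameter in place and returns the module, so a layer such as proj_out starts out contributing nothing to the residual path.

# Assumed zero_module helper, not taken from the diff above.
import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    """Zero out all parameters of `module` and return it."""
    for p in module.parameters():
        p.detach().zero_()
    return module

# Example: an output projection initialized to zero.
proj_out = zero_module(nn.Conv1d(8, 8, 1))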
src/diffusers/models/unet.py
@@ -30,9 +30,9 @@ from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention2d import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import AttentionBlock
 
 
 def nonlinearity(x):
@@ -219,11 +219,11 @@ class UNetModel(ModelMixin, ConfigMixin):
             for i_block in range(self.num_res_blocks):
                 h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
-                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
-                    # h = self.down[i_level].attn_2[i_block](h)
+                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
+                    # h = self.down[i_level].attn_2[i_block](h)
                     h = self.down[i_level].attn[i_block](h)
-                    # print("Result", (h - h_2).abs().sum())
+                    # print("Result", (h - h_2).abs().sum())
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))
src/diffusers/models/unet_grad_tts.py
@@ -3,9 +3,9 @@ from numpy import pad
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention2d import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import LinearAttention
 
 
 class Mish(torch.nn.Module):
src/diffusers/models/unet_sde_score_estimation.py
@@ -16,18 +16,18 @@
 # helpers functions
 import functools
+import math
 import string
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import math
 
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .embeddings import GaussianFourierProjection, get_timestep_embedding
 from .attention2d import AttentionBlock
+from .embeddings import GaussianFourierProjection, get_timestep_embedding
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -728,7 +728,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
             nn.init.zeros_(modules[-1].bias)
 
         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
         Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
 
         if progressive == "output_skip":
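The AttnBlock line kept in this hunk uses functools.partial to pre-bind constructor keyword arguments, so NCSNpp can later instantiate attention blocks by supplying only the per-layer arguments. A small stand-alone sketch of that pattern; BlockStandIn is hypothetical and only mimics the keyword arguments seen above, not the real AttentionBlock signature.

# functools.partial as a block factory with pre-bound keyword arguments.
import functools
import math

class BlockStandIn:
    def __init__(self, channels, overwrite_linear=False, rescale_output_factor=1.0):
        self.channels = channels
        self.overwrite_linear = overwrite_linear
        self.rescale_output_factor = rescale_output_factor

AttnBlock = functools.partial(BlockStandIn, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

block = AttnBlock(channels=128)  # only the per-call argument is supplied
assert block.overwrite_linear and block.rescale_output_factor == math.sqrt(2.0)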
tests/test_modeling_utils.py
@@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase):
         image_slice = image[0, -1, -3:, -3:].cpu()
 
         assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105])
+        expected_slice = torch.tensor(
+            [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
+        )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
 
     @slow
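This change only re-wraps the expected_slice literal; the test keeps the usual slice-comparison pattern: take the 3x3 corner of the last channel, flatten it, and require every element to be within 1e-2 of hard-coded reference values. A runnable sketch of that pattern, using a random tensor as a stand-in for the real pipeline output.

# Slice-comparison test pattern, with a stand-in image instead of pipeline output.
import torch

torch.manual_seed(0)
image = torch.randn(1, 3, 32, 32)

image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 32, 32)

# In the real test this is a hard-coded tensor of nine reference values.
expected_slice = image_slice.flatten().clone()
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2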