Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fa7443c8
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b934215d4c376ea2e08e28103443686b95ea772c"
Commit
fa7443c8
authored
Jul 01, 2022
by
Patrick von Platen
Browse files
finish resnet
parent
8d7771d8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
92 additions
and
25 deletions
+92
-25
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+13
-10
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+79
-15
No files found.
src/diffusers/models/resnet.py
View file @
fa7443c8
...
@@ -380,7 +380,7 @@ class ResnetBlock(nn.Module):
...
@@ -380,7 +380,7 @@ class ResnetBlock(nn.Module):
eps
=
1e-6
,
eps
=
1e-6
,
non_linearity
=
"swish"
,
non_linearity
=
"swish"
,
time_embedding_norm
=
"default"
,
time_embedding_norm
=
"default"
,
fir_
kernel
=
(
1
,
3
,
3
,
1
)
,
kernel
=
None
,
output_scale_factor
=
1.0
,
output_scale_factor
=
1.0
,
use_nin_shortcut
=
None
,
use_nin_shortcut
=
None
,
up
=
False
,
up
=
False
,
...
@@ -433,8 +433,18 @@ class ResnetBlock(nn.Module):
...
@@ -433,8 +433,18 @@ class ResnetBlock(nn.Module):
# elif down:
# elif down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self
.
upsample
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
if
self
.
up
else
None
self
.
downsample
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
self
.
down
else
None
self
.
upsample
=
self
.
downsample
=
None
if
self
.
up
and
kernel
==
"fir"
:
fir_kernel
=
(
1
,
3
,
3
,
1
)
self
.
upsample
=
lambda
x
:
upsample_2d
(
x
,
k
=
fir_kernel
)
elif
self
.
up
and
kernel
is
None
:
self
.
upsample
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
elif
self
.
down
and
kernel
==
"fir"
:
fir_kernel
=
(
1
,
3
,
3
,
1
)
self
.
downsample
=
lambda
x
:
downsample_2d
(
x
,
k
=
fir_kernel
)
elif
self
.
down
and
kernel
is
None
:
self
.
downsample
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
self
.
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
use_nin_shortcut
is
None
else
use_nin_shortcut
self
.
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
use_nin_shortcut
is
None
else
use_nin_shortcut
...
@@ -505,8 +515,6 @@ class ResnetBlock(nn.Module):
...
@@ -505,8 +515,6 @@ class ResnetBlock(nn.Module):
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
num_groups
,
num_channels
=
in_ch
,
eps
=
eps
)
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
num_groups
,
num_channels
=
in_ch
,
eps
=
eps
)
self
.
up
=
up
self
.
up
=
up
self
.
down
=
down
self
.
down
=
down
self
.
fir_kernel
=
fir_kernel
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
...
@@ -525,11 +533,6 @@ class ResnetBlock(nn.Module):
...
@@ -525,11 +533,6 @@ class ResnetBlock(nn.Module):
self
.
out_ch
=
out_ch
self
.
out_ch
=
out_ch
# TODO(Patrick) - move to main init
# TODO(Patrick) - move to main init
if
self
.
up
:
self
.
upsample
=
functools
.
partial
(
upsample_2d
,
k
=
self
.
fir_kernel
)
if
self
.
down
:
self
.
downsample
=
functools
.
partial
(
downsample_2d
,
k
=
self
.
fir_kernel
)
self
.
is_overwritten
=
False
self
.
is_overwritten
=
False
def
set_weights_grad_tts
(
self
):
def
set_weights_grad_tts
(
self
):
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
fa7443c8
...
@@ -348,16 +348,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -348,16 +348,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
for
i_block
in
range
(
num_res_blocks
):
for
i_block
in
range
(
num_res_blocks
):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
modules
.
append
(
ResnetNew
(
modules
.
append
(
in_channels
=
in_ch
,
ResnetNew
(
out_channels
=
out_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
out_channels
=
out_ch
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
temb_channels
=
4
*
nf
,
non_linearity
=
"silu"
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
groups
=
min
(
in_ch
//
4
,
32
),
non_linearity
=
"silu"
,
groups_out
=
min
(
out_ch
//
4
,
32
),
groups
=
min
(
in_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
groups_out
=
min
(
out_ch
//
4
,
32
),
))
overwrite_for_score_vde
=
True
,
)
)
in_ch
=
out_ch
in_ch
=
out_ch
if
all_resolutions
[
i_level
]
in
attn_resolutions
:
if
all_resolutions
[
i_level
]
in
attn_resolutions
:
...
@@ -365,7 +367,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -365,7 +367,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c
.
append
(
in_ch
)
hs_c
.
append
(
in_ch
)
if
i_level
!=
self
.
num_resolutions
-
1
:
if
i_level
!=
self
.
num_resolutions
-
1
:
modules
.
append
(
ResnetBlock
(
down
=
True
,
in_ch
=
in_ch
))
# modules.append(ResnetBlock(down=True, in_ch=in_ch))
modules
.
append
(
ResnetNew
(
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
non_linearity
=
"silu"
,
groups
=
min
(
in_ch
//
4
,
32
),
groups_out
=
min
(
out_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
down
=
True
,
kernel
=
"fir"
,
# TODO(Patrick) - it seems like both fir and non-fir kernels are fine
use_nin_shortcut
=
True
,
)
)
if
progressive_input
==
"input_skip"
:
if
progressive_input
==
"input_skip"
:
modules
.
append
(
combiner
(
dim1
=
input_pyramid_ch
,
dim2
=
in_ch
))
modules
.
append
(
combiner
(
dim1
=
input_pyramid_ch
,
dim2
=
in_ch
))
...
@@ -379,16 +395,50 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -379,16 +395,50 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c
.
append
(
in_ch
)
hs_c
.
append
(
in_ch
)
in_ch
=
hs_c
[
-
1
]
in_ch
=
hs_c
[
-
1
]
modules
.
append
(
ResnetBlock
(
in_ch
=
in_ch
))
# modules.append(ResnetBlock(in_ch=in_ch))
modules
.
append
(
ResnetNew
(
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
non_linearity
=
"silu"
,
groups
=
min
(
in_ch
//
4
,
32
),
groups_out
=
min
(
out_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
)
)
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
modules
.
append
(
ResnetBlock
(
in_ch
=
in_ch
))
# modules.append(ResnetBlock(in_ch=in_ch))
modules
.
append
(
ResnetNew
(
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
non_linearity
=
"silu"
,
groups
=
min
(
in_ch
//
4
,
32
),
groups_out
=
min
(
out_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
)
)
pyramid_ch
=
0
pyramid_ch
=
0
# Upsampling block
# Upsampling block
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
num_res_blocks
+
1
):
for
i_block
in
range
(
num_res_blocks
+
1
):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
modules
.
append
(
ResnetBlock
(
in_ch
=
in_ch
+
hs_c
.
pop
(),
out_ch
=
out_ch
))
# modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
modules
.
append
(
ResnetNew
(
in_channels
=
in_ch
+
hs_c
.
pop
(),
out_channels
=
out_ch
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
non_linearity
=
"silu"
,
groups
=
min
(
in_ch
//
4
,
32
),
groups_out
=
min
(
out_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
)
)
in_ch
=
out_ch
in_ch
=
out_ch
if
all_resolutions
[
i_level
]
in
attn_resolutions
:
if
all_resolutions
[
i_level
]
in
attn_resolutions
:
...
@@ -420,7 +470,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -420,7 +470,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
if
i_level
!=
0
:
if
i_level
!=
0
:
modules
.
append
(
ResnetBlock
(
in_ch
=
in_ch
,
up
=
True
))
# modules.append(ResnetBlock(in_ch=in_ch, up=True))
modules
.
append
(
ResnetNew
(
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
non_linearity
=
"silu"
,
groups
=
min
(
in_ch
//
4
,
32
),
groups_out
=
min
(
out_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
up
=
True
,
kernel
=
"fir"
,
use_nin_shortcut
=
True
,
)
)
assert
not
hs_c
assert
not
hs_c
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment