Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1468f754
Commit
1468f754
authored
Jul 01, 2022
by
Patrick von Platen
Browse files
finish resnet
parent
fa7443c8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
24 deletions
+7
-24
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+7
-24
No files found.
src/diffusers/models/unet_sde_score_estimation.py
View file @
1468f754
...
@@ -28,8 +28,7 @@ from ..modeling_utils import ModelMixin
...
@@ -28,8 +28,7 @@ from ..modeling_utils import ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
ResnetBlockBigGANppNew
as
ResnetBlockBigGANpp
from
.resnet
import
ResnetBlock
from
.resnet
import
ResnetBlock
as
ResnetNew
def
_setup_kernel
(
k
):
def
_setup_kernel
(
k
):
...
@@ -323,16 +322,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -323,16 +322,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
elif
progressive_input
==
"residual"
:
elif
progressive_input
==
"residual"
:
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
fir_kernel
=
fir_kernel
,
with_conv
=
True
)
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
fir_kernel
=
fir_kernel
,
with_conv
=
True
)
ResnetBlock
=
functools
.
partial
(
ResnetBlockBigGANpp
,
act
=
act
,
dropout
=
dropout
,
fir_kernel
=
fir_kernel
,
init_scale
=
init_scale
,
skip_rescale
=
skip_rescale
,
temb_dim
=
nf
*
4
,
)
# Downsampling block
# Downsampling block
channels
=
num_channels
channels
=
num_channels
...
@@ -347,9 +336,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -347,9 +336,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
# Residual blocks for this resolution
# Residual blocks for this resolution
for
i_block
in
range
(
num_res_blocks
):
for
i_block
in
range
(
num_res_blocks
):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
modules
.
append
(
modules
.
append
(
Resnet
New
(
Resnet
Block
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
out_channels
=
out_ch
,
out_channels
=
out_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
...
@@ -367,9 +355,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -367,9 +355,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c
.
append
(
in_ch
)
hs_c
.
append
(
in_ch
)
if
i_level
!=
self
.
num_resolutions
-
1
:
if
i_level
!=
self
.
num_resolutions
-
1
:
# modules.append(ResnetBlock(down=True, in_ch=in_ch))
modules
.
append
(
modules
.
append
(
Resnet
New
(
Resnet
Block
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
output_scale_factor
=
np
.
sqrt
(
2.0
),
...
@@ -395,9 +382,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -395,9 +382,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c
.
append
(
in_ch
)
hs_c
.
append
(
in_ch
)
in_ch
=
hs_c
[
-
1
]
in_ch
=
hs_c
[
-
1
]
# modules.append(ResnetBlock(in_ch=in_ch))
modules
.
append
(
modules
.
append
(
Resnet
New
(
Resnet
Block
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
output_scale_factor
=
np
.
sqrt
(
2.0
),
...
@@ -408,9 +394,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -408,9 +394,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
)
)
)
)
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
modules
.
append
(
AttnBlock
(
channels
=
in_ch
))
# modules.append(ResnetBlock(in_ch=in_ch))
modules
.
append
(
modules
.
append
(
Resnet
New
(
Resnet
Block
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
output_scale_factor
=
np
.
sqrt
(
2.0
),
...
@@ -426,9 +411,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -426,9 +411,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
num_res_blocks
+
1
):
for
i_block
in
range
(
num_res_blocks
+
1
):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
# modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
modules
.
append
(
modules
.
append
(
Resnet
New
(
Resnet
Block
(
in_channels
=
in_ch
+
hs_c
.
pop
(),
in_channels
=
in_ch
+
hs_c
.
pop
(),
out_channels
=
out_ch
,
out_channels
=
out_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
...
@@ -470,9 +454,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -470,9 +454,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
if
i_level
!=
0
:
if
i_level
!=
0
:
# modules.append(ResnetBlock(in_ch=in_ch, up=True))
modules
.
append
(
modules
.
append
(
Resnet
New
(
Resnet
Block
(
in_channels
=
in_ch
,
in_channels
=
in_ch
,
temb_channels
=
4
*
nf
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
output_scale_factor
=
np
.
sqrt
(
2.0
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment