Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
9da575d6
Commit
9da575d6
authored
Jul 01, 2022
by
Patrick von Platen
Browse files
correct more
parent
61dc6574
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
10 deletions
+19
-10
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+5
-9
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+14
-1
No files found.
src/diffusers/models/resnet.py
View file @
9da575d6
...
@@ -404,9 +404,6 @@ class ResnetBlock(nn.Module):
...
@@ -404,9 +404,6 @@ class ResnetBlock(nn.Module):
if
groups_out
is
None
:
if
groups_out
is
None
:
groups_out
=
groups
groups_out
=
groups
if
use_nin_shortcut
is
None
:
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
self
.
pre_norm
:
if
self
.
pre_norm
:
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
else
:
else
:
...
@@ -439,8 +436,11 @@ class ResnetBlock(nn.Module):
...
@@ -439,8 +436,11 @@ class ResnetBlock(nn.Module):
self
.
upsample
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
if
self
.
up
else
None
self
.
upsample
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
if
self
.
up
else
None
self
.
downsample
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
self
.
down
else
None
self
.
downsample
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
self
.
down
else
None
self
.
nin_shortcut
=
None
self
.
nin_shortcut
=
use_nin_shortcut
if
use_nin_shortcut
:
if
self
.
use_nin_shortcut
is
None
:
self
.
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
self
.
use_nin_shortcut
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
...
@@ -613,10 +613,6 @@ class ResnetBlock(nn.Module):
...
@@ -613,10 +613,6 @@ class ResnetBlock(nn.Module):
x
=
self
.
downsample
(
x
)
x
=
self
.
downsample
(
x
)
h
=
self
.
downsample
(
h
)
h
=
self
.
downsample
(
h
)
# if self.up: or self.down:
# x = self.x_upd(x)
# h = self.h_upd(h)
#
h
=
self
.
conv1
(
h
)
h
=
self
.
conv1
(
h
)
if
not
self
.
pre_norm
:
if
not
self
.
pre_norm
:
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
9da575d6
...
@@ -29,6 +29,7 @@ from .attention import AttentionBlock
...
@@ -29,6 +29,7 @@ from .attention import AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
ResnetBlockBigGANppNew
as
ResnetBlockBigGANpp
from
.resnet
import
ResnetBlockBigGANppNew
as
ResnetBlockBigGANpp
from
.resnet
import
ResnetBlock
as
ResnetNew
def
_setup_kernel
(
k
):
def
_setup_kernel
(
k
):
...
@@ -346,7 +347,19 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -346,7 +347,19 @@ class NCSNpp(ModelMixin, ConfigMixin):
# Residual blocks for this resolution
# Residual blocks for this resolution
for
i_block
in
range
(
num_res_blocks
):
for
i_block
in
range
(
num_res_blocks
):
out_ch
=
nf
*
ch_mult
[
i_level
]
out_ch
=
nf
*
ch_mult
[
i_level
]
modules
.
append
(
ResnetBlock
(
in_ch
=
in_ch
,
out_ch
=
out_ch
))
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
modules
.
append
(
ResnetNew
(
in_channels
=
in_ch
,
out_channels
=
out_ch
,
temb_channels
=
4
*
nf
,
output_scale_factor
=
np
.
sqrt
(
2.0
),
non_linearity
=
"silu"
,
groups
=
min
(
in_ch
//
4
,
32
),
groups_out
=
min
(
out_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
)
)
in_ch
=
out_ch
in_ch
=
out_ch
if
all_resolutions
[
i_level
]
in
attn_resolutions
:
if
all_resolutions
[
i_level
]
in
attn_resolutions
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment