Commit 11667d08 — OpenDAS / diffusers

Authored Jul 01, 2022 by Patrick von Platen; committed by GitHub on Jul 01, 2022 (unverified)

Merge pull request #59 from huggingface/fuse_final_resnets

[Resnet] Merge final 2D resnet

Parents: c2bc59d2, 221de0ed
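This merge retires the standalone ResnetBlockBigGANpp used by the score-VDE (NCSN++) model and folds its behaviour into the shared ResnetBlock through new constructor arguments (groups_out, kernel, output_scale_factor, use_nin_shortcut, overwrite_for_score_vde). The shape of the call-site migration, quoted from the hunks below (in_ch, out_ch, and nf are the caller's variables, so this is an excerpt rather than a runnable snippet):

    # before this commit (ResnetBlock was a functools.partial of ResnetBlockBigGANpp):
    modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
    # after - the unified ResnetBlock, configured for the score-VDE path:
    modules.append(
        ResnetBlock(
            in_channels=in_ch,
            out_channels=out_ch,
            temb_channels=4 * nf,
            output_scale_factor=np.sqrt(2.0),
            non_linearity="silu",
            groups=min(in_ch // 4, 32),
            groups_out=min(out_ch // 4, 32),
            overwrite_for_score_vde=True,
        )
    )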
Showing 2 changed files with 176 additions and 42 deletions.

  src/diffusers/models/resnet.py                       +99 −20
  src/diffusers/models/unet_sde_score_estimation.py    +77 −22
src/diffusers/models/resnet.py
@@ -174,9 +174,7 @@ class Downsample(nn.Module):
     #     return self.conv(x)
 
 
 # RESNETS
-# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
@@ -187,15 +185,20 @@ class ResnetBlock(nn.Module):
         dropout=0.0,
         temb_channels=512,
         groups=32,
+        groups_out=None,
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
         time_embedding_norm="default",
+        kernel=None,
+        output_scale_factor=1.0,
+        use_nin_shortcut=None,
         up=False,
         down=False,
         overwrite_for_grad_tts=False,
         overwrite_for_ldm=False,
         overwrite_for_glide=False,
+        overwrite_for_score_vde=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -206,6 +209,10 @@ class ResnetBlock(nn.Module):
         self.time_embedding_norm = time_embedding_norm
         self.up = up
         self.down = down
+        self.output_scale_factor = output_scale_factor
+
+        if groups_out is None:
+            groups_out = groups
 
         if self.pre_norm:
             self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -219,7 +226,7 @@ class ResnetBlock(nn.Module):
         elif time_embedding_norm == "scale_shift":
             self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
 
-        self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+        self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
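The groups → groups_out switch in norm2 is load-bearing: torch.nn.GroupNorm requires num_channels to be divisible by num_groups, so once in_channels and out_channels differ, the group count for norm2 has to be derived from out_channels. A minimal, self-contained illustration (the channel sizes are hypothetical; the min(ch // 4, 32) formula is the one the score-VDE call sites below use):

    import torch.nn as nn

    in_ch, out_ch = 96, 256
    groups = min(in_ch // 4, 32)       # 24, derived from in_channels
    groups_out = min(out_ch // 4, 32)  # 32, derived from out_channels

    nn.GroupNorm(groups_out, out_ch)   # fine: 256 % 32 == 0
    try:
        nn.GroupNorm(groups, out_ch)   # 256 % 24 != 0
    except ValueError as err:
        print(err)                     # num_channels must be divisible by num_groups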
@@ -230,14 +237,29 @@ class ResnetBlock(nn.Module):
         elif non_linearity == "silu":
             self.nonlinearity = nn.SiLU()
 
-        if up:
-            self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
-            self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
-        elif down:
-            self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-            self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+        # if up:
+        #     self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
+        #     self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
+        # elif down:
+        #     self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+        #     self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+
+        self.upsample = self.downsample = None
+        if self.up and kernel == "fir":
+            fir_kernel = (1, 3, 3, 1)
+            self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
+        elif self.up and kernel is None:
+            self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+        elif self.down and kernel == "fir":
+            fir_kernel = (1, 3, 3, 1)
+            self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
+        elif self.down and kernel is None:
+            self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
 
-        if self.in_channels != self.out_channels:
-            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
+
+        self.nin_shortcut = None
+        if self.use_nin_shortcut:
+            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
         # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
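The new use_nin_shortcut argument keeps the old default (a 1×1 shortcut only when channel counts differ) while letting callers force the shortcut, which the score-VDE blocks in the second file do via use_nin_shortcut=True. A plain-Python sketch of the rule:

    def needs_nin_shortcut(in_channels, out_channels, use_nin_shortcut=None):
        # None defers to a channel mismatch; an explicit bool wins.
        return (in_channels != out_channels) if use_nin_shortcut is None else use_nin_shortcut

    assert needs_nin_shortcut(64, 128) is True
    assert needs_nin_shortcut(64, 64) is False
    assert needs_nin_shortcut(64, 64, use_nin_shortcut=True) is True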
@@ -245,6 +267,7 @@ class ResnetBlock(nn.Module):
         self.overwrite_for_glide = overwrite_for_glide
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
         self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
+        self.overwrite_for_score_vde = overwrite_for_score_vde
 
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -260,12 +283,10 @@ class ResnetBlock(nn.Module):
             self.res_conv = torch.nn.Identity()
 
         elif self.overwrite_for_ldm:
             dims = 2
-            # eps = 1e-5
-            # non_linearity = "silu"
-            # overwrite_for_ldm
             channels = in_channels
             emb_channels = temb_channels
             use_scale_shift_norm = False
+            non_linearity = "silu"
 
             self.in_layers = nn.Sequential(
                 normalization(channels, swish=1.0),
@@ -289,6 +310,40 @@ class ResnetBlock(nn.Module):
                 self.skip_connection = nn.Identity()
             else:
                 self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
+        elif self.overwrite_for_score_vde:
+            in_ch = in_channels
+            out_ch = out_channels
+            eps = 1e-6
+            num_groups = min(in_ch // 4, 32)
+            num_groups_out = min(out_ch // 4, 32)
+            temb_dim = temb_channels
+            # output_scale_factor = np.sqrt(2.0)
+            # non_linearity = "silu"
+            # use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
+
+            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
+            self.up = up
+            self.down = down
+            self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
+            if temb_dim is not None:
+                self.Dense_0 = nn.Linear(temb_dim, out_ch)
+                self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
+                nn.init.zeros_(self.Dense_0.bias)
+
+            self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
+            self.Dropout_0 = nn.Dropout(dropout)
+            self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
+            if in_ch != out_ch or up or down:
+                # 1x1 convolution with DDPM initialization.
+                self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
+
+            # self.skip_rescale = skip_rescale
+            self.in_ch = in_ch
+            self.out_ch = out_ch
+
+        # TODO(Patrick) - move to main init
+        self.is_overwritten = False
+
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
@@ -328,6 +383,24 @@ class ResnetBlock(nn.Module):
         self.nin_shortcut.weight.data = self.skip_connection.weight.data
         self.nin_shortcut.bias.data = self.skip_connection.bias.data
 
+    def set_weights_score_vde(self):
+        self.conv1.weight.data = self.Conv_0.weight.data
+        self.conv1.bias.data = self.Conv_0.bias.data
+
+        self.norm1.weight.data = self.GroupNorm_0.weight.data
+        self.norm1.bias.data = self.GroupNorm_0.bias.data
+
+        self.conv2.weight.data = self.Conv_1.weight.data
+        self.conv2.bias.data = self.Conv_1.bias.data
+
+        self.norm2.weight.data = self.GroupNorm_1.weight.data
+        self.norm2.bias.data = self.GroupNorm_1.bias.data
+
+        self.temb_proj.weight.data = self.Dense_0.weight.data
+        self.temb_proj.bias.data = self.Dense_0.bias.data
+
+        if self.in_channels != self.out_channels or self.up or self.down:
+            self.nin_shortcut.weight.data = self.Conv_2.weight.data
+            self.nin_shortcut.bias.data = self.Conv_2.bias.data
+
     def forward(self, x, temb, mask=1.0):
         # TODO(Patrick) eventually this class should be split into multiple classes
         # too many if else statements
@@ -337,6 +410,9 @@ class ResnetBlock(nn.Module):
         elif self.overwrite_for_ldm and not self.is_overwritten:
             self.set_weights_ldm()
             self.is_overwritten = True
+        elif self.overwrite_for_score_vde and not self.is_overwritten:
+            self.set_weights_score_vde()
+            self.is_overwritten = True
 
         h = x
         h = h * mask
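The hook added here follows the same lazy-conversion pattern as the ldm/glide branches: on the first forward call, weights are copied from the legacy score-VDE modules (Conv_0, GroupNorm_0, Dense_0, ...) into the unified ones (conv1, norm1, temb_proj, ...), and the flag prevents repeats. A self-contained toy version of the pattern, with illustrative names rather than the repo's:

    import torch
    import torch.nn as nn

    class LazyOverwriteDemo(nn.Module):
        def __init__(self):
            super().__init__()
            self.Conv_0 = nn.Conv2d(4, 4, 3, padding=1)  # legacy module holding checkpoint weights
            self.conv1 = nn.Conv2d(4, 4, 3, padding=1)   # unified module used for compute
            self.is_overwritten = False

        def forward(self, x):
            if not self.is_overwritten:
                self.conv1.weight.data = self.Conv_0.weight.data
                self.conv1.bias.data = self.Conv_0.bias.data
                self.is_overwritten = True
            return self.conv1(x)

    demo = LazyOverwriteDemo()
    out = demo(torch.randn(1, 4, 8, 8))  # first call copies the weights, later calls reuse them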
@@ -344,9 +420,12 @@ class ResnetBlock(nn.Module):
         h = self.norm1(h)
         h = self.nonlinearity(h)
 
-        if self.up or self.down:
-            x = self.x_upd(x)
-            h = self.h_upd(h)
+        if self.upsample is not None:
+            x = self.upsample(x)
+            h = self.upsample(h)
+        elif self.downsample is not None:
+            x = self.downsample(x)
+            h = self.downsample(h)
 
         h = self.conv1(h)
@@ -379,10 +458,10 @@ class ResnetBlock(nn.Module):
         h = h * mask
         x = x * mask
 
-        if self.in_channels != self.out_channels:
+        if self.nin_shortcut is not None:
             x = self.nin_shortcut(x)
 
-        return x + h
+        return (x + h) / self.output_scale_factor
 
     # TODO(Patrick) - just there to convert the weights; can delete afterward
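The new return value divides by output_scale_factor; the score-VDE call sites below pass np.sqrt(2.0), the BigGAN++-style skip rescaling that keeps the sum of two roughly unit-variance branches at roughly unit variance. A quick numerical sanity check, illustrative only:

    import torch

    x, h = torch.randn(100_000), torch.randn(100_000)
    print(((x + h) / 2 ** 0.5).var())  # ~1.0, instead of ~2.0 for the raw sum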
src/diffusers/models/unet_sde_score_estimation.py
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d
 
 
 def _setup_kernel(k):
@@ -276,8 +276,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
             skip_rescale=skip_rescale,
             continuous=continuous,
         )
-        self.act = act = nn.SiLU()
+        self.act = nn.SiLU()
         self.nf = nf
         self.num_res_blocks = num_res_blocks
         self.attn_resolutions = attn_resolutions
@@ -333,19 +332,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
         elif progressive_input == "residual":
             pyramid_downsample = functools.partial(Downsample, use_conv=True)
 
-        ResnetBlock = functools.partial(
-            ResnetBlockBigGANpp,
-            act=act,
-            dropout=dropout,
-            fir=fir,
-            fir_kernel=fir_kernel,
-            init_scale=init_scale,
-            skip_rescale=skip_rescale,
-            temb_dim=nf * 4,
-        )
-
         # Downsampling block
         channels = num_channels
         if progressive_input != "none":
             input_pyramid_ch = channels
@@ -358,7 +344,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
             # Residual blocks for this resolution
             for i_block in range(num_res_blocks):
                 out_ch = nf * ch_mult[i_level]
-                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                modules.append(
+                    ResnetBlock(
+                        in_channels=in_ch,
+                        out_channels=out_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                    )
+                )
                 in_ch = out_ch
 
                 if all_resolutions[i_level] in attn_resolutions:
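An earlier hunk removed the functools.partial alias over ResnetBlockBigGANpp, so every call site now spells out the shared score-VDE arguments. A partial over the unified block would restore the brevity; this is a hypothetical convenience, not something this commit adds (nf and the import path are assumptions for illustration):

    import functools
    import numpy as np
    from diffusers.models.resnet import ResnetBlock  # module path as of this commit

    nf = 128  # hypothetical base channel count
    score_vde_resnet = functools.partial(
        ResnetBlock,
        temb_channels=4 * nf,
        output_scale_factor=np.sqrt(2.0),
        non_linearity="silu",
        overwrite_for_score_vde=True,
    )
    # score_vde_resnet(in_channels=in_ch, out_channels=out_ch,
    #                  groups=min(in_ch // 4, 32), groups_out=min(out_ch // 4, 32))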
@@ -366,7 +363,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     hs_c.append(in_ch)
 
             if i_level != self.num_resolutions - 1:
-                modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(
+                    ResnetBlock(
+                        in_channels=in_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                        down=True,
+                        kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                        use_nin_shortcut=True,
+                    )
+                )
                 if progressive_input == "input_skip":
                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -380,16 +390,48 @@ class NCSNpp(ModelMixin, ConfigMixin):
             hs_c.append(in_ch)
 
         in_ch = hs_c[-1]
-        modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(
+            ResnetBlock(
+                in_channels=in_ch,
+                temb_channels=4 * nf,
+                output_scale_factor=np.sqrt(2.0),
+                non_linearity="silu",
+                groups=min(in_ch // 4, 32),
+                groups_out=min(out_ch // 4, 32),
+                overwrite_for_score_vde=True,
+            )
+        )
         modules.append(AttnBlock(channels=in_ch))
-        modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(
+            ResnetBlock(
+                in_channels=in_ch,
+                temb_channels=4 * nf,
+                output_scale_factor=np.sqrt(2.0),
+                non_linearity="silu",
+                groups=min(in_ch // 4, 32),
+                groups_out=min(out_ch // 4, 32),
+                overwrite_for_score_vde=True,
+            )
+        )
 
         pyramid_ch = 0
         # Upsampling block
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(num_res_blocks + 1):
                 out_ch = nf * ch_mult[i_level]
-                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                in_ch = in_ch + hs_c.pop()
+                modules.append(
+                    ResnetBlock(
+                        in_channels=in_ch,
+                        out_channels=out_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                    )
+                )
                 in_ch = out_ch
 
             if all_resolutions[i_level] in attn_resolutions:
@@ -421,7 +463,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{progressive} is not a valid name")
 
             if i_level != 0:
-                modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(
+                    ResnetBlock(
+                        in_channels=in_ch,
+                        temb_channels=4 * nf,
+                        output_scale_factor=np.sqrt(2.0),
+                        non_linearity="silu",
+                        groups=min(in_ch // 4, 32),
+                        groups_out=min(out_ch // 4, 32),
+                        overwrite_for_score_vde=True,
+                        up=True,
+                        kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                        use_nin_shortcut=True,
+                    )
+                )
 
         assert not hs_c
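All of the migrated call sites derive GroupNorm group counts the same way: one group per four channels, capped at 32, matching what the overwrite_for_score_vde branch in resnet.py computes internally. As a standalone helper:

    def score_vde_num_groups(channels):
        # min(channels // 4, 32): one group per 4 channels, capped at 32
        return min(channels // 4, 32)

    assert score_vde_num_groups(64) == 16
    assert score_vde_num_groups(256) == 32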