Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
dcb9070b
Commit
dcb9070b
authored
Jul 01, 2022
by
Patrick von Platen
Browse files
quick fix to include non-fir kernels for sde-vp
parent
11667d08
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
100 deletions
+18
-100
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+16
-98
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+2
-2
No files found.
src/diffusers/models/resnet.py
View file @
dcb9070b
...
@@ -237,24 +237,23 @@ class ResnetBlock(nn.Module):
...
@@ -237,24 +237,23 @@ class ResnetBlock(nn.Module):
elif
non_linearity
==
"silu"
:
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
self
.
nonlinearity
=
nn
.
SiLU
()
# if up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
# elif down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self
.
upsample
=
self
.
downsample
=
None
self
.
upsample
=
self
.
downsample
=
None
if
self
.
up
and
kernel
==
"fir"
:
if
self
.
up
:
fir_kernel
=
(
1
,
3
,
3
,
1
)
if
kernel
==
"fir"
:
self
.
upsample
=
lambda
x
:
upsample_2d
(
x
,
k
=
fir_kernel
)
fir_kernel
=
(
1
,
3
,
3
,
1
)
elif
self
.
up
and
kernel
is
None
:
self
.
upsample
=
lambda
x
:
upsample_2d
(
x
,
k
=
fir_kernel
)
self
.
upsample
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
elif
kernel
==
"sde_vp"
:
elif
self
.
down
and
kernel
==
"fir"
:
self
.
upsample
=
partial
(
F
.
interpolate
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
fir_kernel
=
(
1
,
3
,
3
,
1
)
else
:
self
.
downsample
=
lambda
x
:
downsample_2d
(
x
,
k
=
fir_kernel
)
self
.
upsample
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
elif
self
.
down
and
kernel
is
None
:
elif
self
.
down
:
self
.
downsample
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
kernel
==
"fir"
:
fir_kernel
=
(
1
,
3
,
3
,
1
)
self
.
downsample
=
lambda
x
:
downsample_2d
(
x
,
k
=
fir_kernel
)
elif
kernel
==
"sde_vp"
:
self
.
downsample
=
partial
(
F
.
avg_pool2d
,
kernel_size
=
2
,
stride
=
2
)
else
:
self
.
downsample
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
self
.
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
use_nin_shortcut
is
None
else
use_nin_shortcut
self
.
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
use_nin_shortcut
is
None
else
use_nin_shortcut
...
@@ -473,87 +472,6 @@ class Block(torch.nn.Module):
...
@@ -473,87 +472,6 @@ class Block(torch.nn.Module):
)
)
# unet_score_estimation.py
class
ResnetBlockBigGANpp
(
nn
.
Module
):
def
__init__
(
self
,
act
,
in_ch
,
out_ch
=
None
,
temb_dim
=
None
,
up
=
False
,
down
=
False
,
dropout
=
0.1
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
skip_rescale
=
True
,
init_scale
=
0.0
,
):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
up
=
up
self
.
down
=
down
self
.
fir
=
fir
self
.
fir_kernel
=
fir_kernel
if
self
.
up
:
if
self
.
fir
:
self
.
upsample
=
partial
(
upsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
upsample
=
partial
(
F
.
interpolate
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
elif
self
.
down
:
if
self
.
fir
:
self
.
downsample
=
partial
(
downsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
downsample
=
partial
(
F
.
avg_pool2d
,
kernel_size
=
2
,
stride
=
2
)
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
self
.
Dense_0
.
weight
.
data
=
variance_scaling
()(
self
.
Dense_0
.
weight
.
shape
)
nn
.
init
.
zeros_
(
self
.
Dense_0
.
bias
)
self
.
GroupNorm_1
=
nn
.
GroupNorm
(
num_groups
=
min
(
out_ch
//
4
,
32
),
num_channels
=
out_ch
,
eps
=
1e-6
)
self
.
Dropout_0
=
nn
.
Dropout
(
dropout
)
self
.
Conv_1
=
conv2d
(
out_ch
,
out_ch
,
init_scale
=
init_scale
,
kernel_size
=
3
,
padding
=
1
)
if
in_ch
!=
out_ch
or
up
or
down
:
# 1x1 convolution with DDPM initialization.
self
.
Conv_2
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
1
,
padding
=
0
)
self
.
skip_rescale
=
skip_rescale
self
.
act
=
act
self
.
in_ch
=
in_ch
self
.
out_ch
=
out_ch
def
forward
(
self
,
x
,
temb
=
None
):
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
h
=
self
.
upsample
(
h
)
x
=
self
.
upsample
(
x
)
elif
self
.
down
:
h
=
self
.
downsample
(
h
)
x
=
self
.
downsample
(
x
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
if
temb
is
not
None
:
h
+=
self
.
Dense_0
(
self
.
act
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
act
(
self
.
GroupNorm_1
(
h
))
h
=
self
.
Dropout_0
(
h
)
h
=
self
.
Conv_1
(
h
)
if
self
.
in_ch
!=
self
.
out_ch
or
self
.
up
or
self
.
down
:
x
=
self
.
Conv_2
(
x
)
if
not
self
.
skip_rescale
:
return
x
+
h
else
:
return
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
# unet_rl.py
# unet_rl.py
class
ResidualTemporalBlock
(
nn
.
Module
):
class
ResidualTemporalBlock
(
nn
.
Module
):
def
__init__
(
self
,
inp_channels
,
out_channels
,
embed_dim
,
horizon
,
kernel_size
=
5
):
def
__init__
(
self
,
inp_channels
,
out_channels
,
embed_dim
,
horizon
,
kernel_size
=
5
):
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
dcb9070b
...
@@ -373,7 +373,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -373,7 +373,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
groups_out
=
min
(
out_ch
//
4
,
32
),
groups_out
=
min
(
out_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
overwrite_for_score_vde
=
True
,
down
=
True
,
down
=
True
,
kernel
=
"fir"
,
# TODO(Patrick) - it seems like both fir and non-fir kernels are fine
kernel
=
"fir"
if
self
.
fir
else
"sde_vp"
,
use_nin_shortcut
=
True
,
use_nin_shortcut
=
True
,
)
)
)
)
...
@@ -473,7 +473,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -473,7 +473,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
groups_out
=
min
(
out_ch
//
4
,
32
),
groups_out
=
min
(
out_ch
//
4
,
32
),
overwrite_for_score_vde
=
True
,
overwrite_for_score_vde
=
True
,
up
=
True
,
up
=
True
,
kernel
=
"fir"
,
# TODO(Patrick) - it seems like both fir and non-fir kernels are fine
kernel
=
"fir"
if
self
.
fir
else
"sde_vp"
,
use_nin_shortcut
=
True
,
use_nin_shortcut
=
True
,
)
)
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment