Commit 663393e2 in renzhc/diffusers_dcu
Authored Jun 30, 2022 by patil-suraj
Parent: c50d9975

remove fir option

This commit drops the fir flag from ResnetBlockBigGANpp, Upsample, and Downsample, making the FIR-kernel resampling path (configured by fir_kernel alone) the only implementation. fir survives only in the NCSNpp signature, with a TODO to remove it from pre-trained model configs as well.
Showing 2 changed files with 9 additions and 17 deletions:

    src/diffusers/models/resnet.py                       +0 -2
    src/diffusers/models/unet_sde_score_estimation.py    +9 -15
src/diffusers/models/resnet.py

@@ -579,7 +579,6 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
-        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -590,7 +589,6 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
-        self.fir = fir
         self.fir_kernel = fir_kernel
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
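After this change ResnetBlockBigGANpp no longer accepts fir; the FIR kernel alone drives resampling inside the block. A hypothetical call shape under that assumption (it uses only parameters visible in these hunks, and the import path reflects this commit's tree, not later layouts):

    import torch.nn as nn
    from diffusers.models.resnet import ResnetBlockBigGANpp  # path as of this commit

    block = ResnetBlockBigGANpp(
        act=nn.SiLU(),             # activation callable, passed through by NCSNpp below
        in_ch=128,
        out_ch=128,
        up=True,                   # resampling now always goes through fir_kernel
        dropout=0.1,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    )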
src/diffusers/models/unet_sde_score_estimation.py
@@ -334,7 +334,7 @@ class Combine(nn.Module):
 class Upsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
@@ -347,13 +347,11 @@ class Upsample(nn.Module):
                 use_bias=True,
                 kernel_init=variance_scaling(),
             )
-        self.fir = fir
         self.with_conv = with_conv
         self.fir_kernel = fir_kernel
         self.out_ch = out_ch

     def forward(self, x):
         B, C, H, W = x.shape
         if not self.with_conv:
             h = upsample_2d(x, self.fir_kernel, factor=2)
         else:
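upsample_2d itself is not part of this diff. As a rough sketch of what FIR upsampling with fir_kernel=(1, 3, 3, 1) typically does (the upfirdn pattern: zero-stuff by the factor, then low-pass with the separable kernel, gain-compensated), not the library's exact implementation:

    import torch
    import torch.nn.functional as F

    def fir_upsample_2d_sketch(x, kernel=(1, 3, 3, 1), factor=2):
        # Build the separable 2D low-pass filter; scale by factor**2 so the
        # zero-stuffed signal keeps its average brightness.
        k = torch.tensor(kernel, dtype=x.dtype, device=x.device)
        k = torch.outer(k, k)
        k = k / k.sum() * (factor**2)
        B, C, H, W = x.shape
        # Zero-stuffing: place each input pixel on a factor-spaced grid.
        up = torch.zeros(B, C, H * factor, W * factor, dtype=x.dtype, device=x.device)
        up[:, :, ::factor, ::factor] = x
        # Depthwise convolution with the FIR kernel, padded to preserve size.
        kh, kw = k.shape
        up = F.pad(up, ((kw - 1) // 2, kw // 2, (kh - 1) // 2, kh // 2))
        weight = k.view(1, 1, kh, kw).repeat(C, 1, 1, 1)
        return F.conv2d(up, weight, groups=C)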
@@ -363,7 +361,7 @@ class Upsample(nn.Module):
 class Downsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
@@ -376,13 +374,11 @@ class Downsample(nn.Module):
                 use_bias=True,
                 kernel_init=variance_scaling(),
             )
-        self.fir = fir
         self.fir_kernel = fir_kernel
         self.with_conv = with_conv
         self.out_ch = out_ch

     def forward(self, x):
         B, C, H, W = x.shape
         if not self.with_conv:
             x = downsample_2d(x, self.fir_kernel, factor=2)
         else:
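downsample_2d mirrors the upsampling sketch above: low-pass filter first, then decimate. Again a sketch of the usual technique (reusing the imports from the previous block), not the library function itself:

    def fir_downsample_2d_sketch(x, kernel=(1, 3, 3, 1), factor=2):
        # Filter with the separable FIR kernel at unit gain (no zeros are
        # inserted here), then keep every factor-th pixel via a strided
        # depthwise convolution.
        k = torch.tensor(kernel, dtype=x.dtype, device=x.device)
        k = torch.outer(k, k)
        k = k / k.sum()
        kh, kw = k.shape
        C = x.shape[1]
        x = F.pad(x, ((kw - factor + 1) // 2, (kw - factor) // 2,
                      (kh - factor + 1) // 2, (kh - factor) // 2))
        weight = k.view(1, 1, kh, kw).repeat(C, 1, 1, 1)
        return F.conv2d(x, weight, stride=factor, groups=C)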
@@ -404,7 +400,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,
+        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
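A guess at why the argument survives here rather than being deleted outright: saved model configs record constructor arguments and replay them as keyword arguments at load time, so any key still present in a stored config must remain accepted by __init__ until those configs are updated. The generic shape of that constraint (load_model and config_path are hypothetical names, not diffusers API):

    import json

    def load_model(cls, config_path):
        # Saved configs are replayed as **kwargs, so cls.__init__ must still
        # accept every key recorded in the JSON, including deprecated ones
        # like fir, until the stored configs themselves are rewritten.
        with open(config_path) as f:
            kwargs = json.load(f)
        return cls(**kwargs)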
@@ -428,7 +424,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
-            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
@@ -483,25 +478,24 @@ class NCSNpp(ModelMixin, ConfigMixin):
             nn.init.zeros_(modules[-1].bias)

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

-        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)

-        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)

         ResnetBlock = functools.partial(
             ResnetBlockBigGANpp,
             act=act,
             dropout=dropout,
-            fir=fir,
             fir_kernel=fir_kernel,
             init_scale=init_scale,
             skip_rescale=skip_rescale,
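One detail worth noting in these hunks: pyramid_upsample still passes fir_kernel and with_conv to Up_sample even though the partial already binds them. That is well-defined Python, since keywords supplied at call time extend and override the ones bound by functools.partial. A minimal demonstration with a toy function (not from the codebase):

    import functools

    def resample(with_conv=False, fir_kernel=None):
        return with_conv, fir_kernel

    # Bind defaults once, as the NCSNpp constructor does with Up_sample.
    Up_sample = functools.partial(resample, with_conv=True, fir_kernel=(1, 3, 3, 1))

    print(Up_sample())                 # (True, (1, 3, 3, 1))
    print(Up_sample(with_conv=False))  # (False, (1, 3, 3, 1)); call-site keyword wins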