Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
c691bb2f
Unverified
Commit
c691bb2f
authored
Jul 01, 2022
by
Suraj Patil
Committed by
GitHub
Jul 01, 2022
Browse files
Merge pull request #60 from huggingface/add-fir-back
fix unde sde for vp model.
parents
abedfb08
4c293e0e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
33 deletions
+75
-33
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+36
-8
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+39
-25
No files found.
src/diffusers/models/resnet.py
View file @
c691bb2f
from
abc
import
abstractmethod
from
functools
import
partial
import
numpy
as
np
import
torch
...
...
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
dims
=
2
,
out_channels
=
None
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
dims
=
2
,
out_channels
=
None
,
name
=
"conv"
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
self
.
use_conv_transpose
=
use_conv_transpose
self
.
name
=
name
conv
=
None
if
use_conv_transpose
:
self
.
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
elif
use_conv
:
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
if
name
==
"conv"
:
self
.
conv
=
conv
else
:
self
.
Conv2d_0
=
conv
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
...
...
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
if
self
.
name
==
"conv"
:
x
=
self
.
conv
(
x
)
else
:
x
=
self
.
Conv2d_0
(
x
)
return
x
...
...
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
if
name
==
"conv"
:
self
.
conv
=
conv
elif
name
==
"Conv2d_0"
:
self
.
Conv2d_0
=
conv
else
:
self
.
op
=
conv
...
...
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
if
self
.
name
==
"conv"
:
return
self
.
conv
(
x
)
elif
self
.
name
==
"Conv2d_0"
:
return
self
.
Conv2d_0
(
x
)
else
:
return
self
.
op
(
x
)
...
...
@@ -390,6 +405,7 @@ class ResnetBlockBigGANpp(nn.Module):
up
=
False
,
down
=
False
,
dropout
=
0.1
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
skip_rescale
=
True
,
init_scale
=
0.0
,
...
...
@@ -400,8 +416,20 @@ class ResnetBlockBigGANpp(nn.Module):
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
up
=
up
self
.
down
=
down
self
.
fir
=
fir
self
.
fir_kernel
=
fir_kernel
if
self
.
up
:
if
self
.
fir
:
self
.
upsample
=
partial
(
upsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
upsample
=
partial
(
F
.
interpolate
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
elif
self
.
down
:
if
self
.
fir
:
self
.
downsample
=
partial
(
downsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
downsample
=
partial
(
F
.
avg_pool2d
,
kernel_size
=
2
,
stride
=
2
)
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
...
...
@@ -424,11 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
h
=
upsample
_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample
_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
upsample
(
h
)
x
=
self
.
upsample
(
x
)
elif
self
.
down
:
h
=
downsample
_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample
_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
downsample
(
h
)
x
=
self
.
downsample
(
x
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
c691bb2f
...
...
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
ResnetBlockBigGANpp
,
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
Downsample
,
ResnetBlockBigGANpp
,
Upsample
,
downsample_2d
,
upfirdn2d
,
upsample_2d
def
_setup_kernel
(
k
):
...
...
@@ -184,18 +184,19 @@ class Combine(nn.Module):
class
FirUpsample
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
channels
=
None
,
out_ch
annels
=
None
,
use
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
if
with
_conv
:
self
.
Conv2d_0
=
Conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
with
_conv
=
with
_conv
out_ch
annels
=
out_ch
annels
if
out_ch
annels
else
channels
if
use
_conv
:
self
.
Conv2d_0
=
Conv2d
(
channels
,
out_ch
annels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
use
_conv
=
use
_conv
self
.
fir_kernel
=
fir_kernel
self
.
out_ch
=
out_ch
self
.
out_ch
annels
=
out_ch
annels
def
forward
(
self
,
x
):
if
self
.
with
_conv
:
if
self
.
use
_conv
:
h
=
_upsample_conv_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
h
=
h
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
h
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
...
...
@@ -203,18 +204,19 @@ class FirUpsample(nn.Module):
class
FirDownsample
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
channels
=
None
,
out_ch
annels
=
None
,
use
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
if
with
_conv
:
self
.
Conv2d_0
=
self
.
Conv2d_0
=
Conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
out_ch
annels
=
out_ch
annels
if
out_ch
annels
else
channels
if
use
_conv
:
self
.
Conv2d_0
=
self
.
Conv2d_0
=
Conv2d
(
channels
,
out_ch
annels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
fir_kernel
=
fir_kernel
self
.
with
_conv
=
with
_conv
self
.
out_ch
=
out_ch
self
.
use
_conv
=
use
_conv
self
.
out_ch
annels
=
out_ch
annels
def
forward
(
self
,
x
):
if
self
.
with
_conv
:
if
self
.
use
_conv
:
x
=
_conv_downsample_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
x
=
x
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
...
...
@@ -228,13 +230,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
,
image_size
=
1024
,
num_channels
=
3
,
centered
=
False
,
attn_resolutions
=
(
16
,),
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
),
conditional
=
True
,
conv_size
=
3
,
dropout
=
0.0
,
embedding_type
=
"fourier"
,
fir
=
True
,
# TODO (patil-suraj) remove this option from here and pre-trained model configs
fir
=
True
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
fourier_scale
=
16
,
init_scale
=
0.0
,
...
...
@@ -252,12 +255,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
.
register_to_config
(
image_size
=
image_size
,
num_channels
=
num_channels
,
centered
=
centered
,
attn_resolutions
=
attn_resolutions
,
ch_mult
=
ch_mult
,
conditional
=
conditional
,
conv_size
=
conv_size
,
dropout
=
dropout
,
embedding_type
=
embedding_type
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
fourier_scale
=
fourier_scale
,
init_scale
=
init_scale
,
...
...
@@ -307,24 +312,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules
.
append
(
Linear
(
nf
*
4
,
nf
*
4
))
AttnBlock
=
functools
.
partial
(
AttentionBlock
,
overwrite_linear
=
True
,
rescale_output_factor
=
math
.
sqrt
(
2.0
))
Up_sample
=
functools
.
partial
(
FirUpsample
,
with_conv
=
resamp_with_conv
,
fir_kernel
=
fir_kernel
)
if
self
.
fir
:
Up_sample
=
functools
.
partial
(
FirUpsample
,
fir_kernel
=
fir_kernel
,
use_conv
=
resamp_with_conv
)
else
:
Up_sample
=
functools
.
partial
(
Upsample
,
name
=
"Conv2d_0"
)
if
progressive
==
"output_skip"
:
self
.
pyramid_upsample
=
Up_sample
(
fir_kernel
=
fir_ker
ne
l
,
with
_conv
=
False
)
self
.
pyramid_upsample
=
Up_sample
(
channels
=
No
ne
,
use
_conv
=
False
)
elif
progressive
==
"residual"
:
pyramid_upsample
=
functools
.
partial
(
Up_sample
,
fir_kernel
=
fir_kernel
,
with
_conv
=
True
)
pyramid_upsample
=
functools
.
partial
(
Up_sample
,
use
_conv
=
True
)
Down_sample
=
functools
.
partial
(
FirDownsample
,
with_conv
=
resamp_with_conv
,
fir_kernel
=
fir_kernel
)
if
self
.
fir
:
Down_sample
=
functools
.
partial
(
FirDownsample
,
fir_kernel
=
fir_kernel
,
use_conv
=
resamp_with_conv
)
else
:
Down_sample
=
functools
.
partial
(
Downsample
,
padding
=
0
,
name
=
"Conv2d_0"
)
if
progressive_input
==
"input_skip"
:
self
.
pyramid_downsample
=
Down_sample
(
fir_kernel
=
fir_ker
ne
l
,
with
_conv
=
False
)
self
.
pyramid_downsample
=
Down_sample
(
channels
=
No
ne
,
use
_conv
=
False
)
elif
progressive_input
==
"residual"
:
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
fir_kernel
=
fir_kernel
,
with
_conv
=
True
)
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
use
_conv
=
True
)
ResnetBlock
=
functools
.
partial
(
ResnetBlockBigGANpp
,
act
=
act
,
dropout
=
dropout
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
init_scale
=
init_scale
,
skip_rescale
=
skip_rescale
,
...
...
@@ -361,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch
*=
2
elif
progressive_input
==
"residual"
:
modules
.
append
(
pyramid_downsample
(
in_ch
=
input_pyramid_ch
,
out_ch
=
in_ch
))
modules
.
append
(
pyramid_downsample
(
channels
=
input_pyramid_ch
,
out_ch
annels
=
in_ch
))
input_pyramid_ch
=
in_ch
hs_c
.
append
(
in_ch
)
...
...
@@ -402,7 +415,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
)
pyramid_ch
=
channels
elif
progressive
==
"residual"
:
modules
.
append
(
pyramid_upsample
(
in_ch
=
pyramid_ch
,
out_ch
=
in_ch
))
modules
.
append
(
pyramid_upsample
(
channels
=
pyramid_ch
,
out_ch
annels
=
in_ch
))
pyramid_ch
=
in_ch
else
:
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
...
...
@@ -446,7 +459,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
temb
=
None
# If input data is in [0, 1]
x
=
2
*
x
-
1.0
if
not
self
.
config
.
centered
:
x
=
2
*
x
-
1.0
# Downsampling block
input_pyramid
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment