Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
c691bb2f
Unverified
Commit
c691bb2f
authored
Jul 01, 2022
by
Suraj Patil
Committed by
GitHub
Jul 01, 2022
Browse files
Merge pull request #60 from huggingface/add-fir-back
fix unde sde for vp model.
parents
abedfb08
4c293e0e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
33 deletions
+75
-33
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+36
-8
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+39
-25
No files found.
src/diffusers/models/resnet.py
View file @
c691bb2f
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
functools
import
partial
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
...
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
upsampling occurs in the inner-two dimensions.
"""
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
dims
=
2
,
out_channels
=
None
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
dims
=
2
,
out_channels
=
None
,
name
=
"conv"
):
super
().
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
use_conv
=
use_conv
self
.
dims
=
dims
self
.
dims
=
dims
self
.
use_conv_transpose
=
use_conv_transpose
self
.
use_conv_transpose
=
use_conv_transpose
self
.
name
=
name
conv
=
None
if
use_conv_transpose
:
if
use_conv_transpose
:
self
.
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
elif
use_conv
:
elif
use_conv
:
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
if
name
==
"conv"
:
self
.
conv
=
conv
else
:
self
.
Conv2d_0
=
conv
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
x
.
shape
[
1
]
==
self
.
channels
...
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
...
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
use_conv
:
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
if
self
.
name
==
"conv"
:
x
=
self
.
conv
(
x
)
else
:
x
=
self
.
Conv2d_0
(
x
)
return
x
return
x
...
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
...
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
if
name
==
"conv"
:
if
name
==
"conv"
:
self
.
conv
=
conv
self
.
conv
=
conv
elif
name
==
"Conv2d_0"
:
self
.
Conv2d_0
=
conv
else
:
else
:
self
.
op
=
conv
self
.
op
=
conv
...
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
...
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
if
self
.
name
==
"conv"
:
if
self
.
name
==
"conv"
:
return
self
.
conv
(
x
)
return
self
.
conv
(
x
)
elif
self
.
name
==
"Conv2d_0"
:
return
self
.
Conv2d_0
(
x
)
else
:
else
:
return
self
.
op
(
x
)
return
self
.
op
(
x
)
...
@@ -390,6 +405,7 @@ class ResnetBlockBigGANpp(nn.Module):
...
@@ -390,6 +405,7 @@ class ResnetBlockBigGANpp(nn.Module):
up
=
False
,
up
=
False
,
down
=
False
,
down
=
False
,
dropout
=
0.1
,
dropout
=
0.1
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
fir_kernel
=
(
1
,
3
,
3
,
1
),
skip_rescale
=
True
,
skip_rescale
=
True
,
init_scale
=
0.0
,
init_scale
=
0.0
,
...
@@ -400,8 +416,20 @@ class ResnetBlockBigGANpp(nn.Module):
...
@@ -400,8 +416,20 @@ class ResnetBlockBigGANpp(nn.Module):
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
up
=
up
self
.
up
=
up
self
.
down
=
down
self
.
down
=
down
self
.
fir
=
fir
self
.
fir_kernel
=
fir_kernel
self
.
fir_kernel
=
fir_kernel
if
self
.
up
:
if
self
.
fir
:
self
.
upsample
=
partial
(
upsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
upsample
=
partial
(
F
.
interpolate
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
elif
self
.
down
:
if
self
.
fir
:
self
.
downsample
=
partial
(
downsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
downsample
=
partial
(
F
.
avg_pool2d
,
kernel_size
=
2
,
stride
=
2
)
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
...
@@ -424,11 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
...
@@ -424,11 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
if
self
.
up
:
h
=
upsample
_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
upsample
(
h
)
x
=
upsample
_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
x
=
self
.
upsample
(
x
)
elif
self
.
down
:
elif
self
.
down
:
h
=
downsample
_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
downsample
(
h
)
x
=
downsample
_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
x
=
self
.
downsample
(
x
)
h
=
self
.
Conv_0
(
h
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
# Add bias to each feature map conditioned on the time embedding
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
c691bb2f
...
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
ResnetBlockBigGANpp
,
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
Downsample
,
ResnetBlockBigGANpp
,
Upsample
,
downsample_2d
,
upfirdn2d
,
upsample_2d
def
_setup_kernel
(
k
):
def
_setup_kernel
(
k
):
...
@@ -184,18 +184,19 @@ class Combine(nn.Module):
...
@@ -184,18 +184,19 @@ class Combine(nn.Module):
class
FirUpsample
(
nn
.
Module
):
class
FirUpsample
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
channels
=
None
,
out_ch
annels
=
None
,
use
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
out_ch
annels
=
out_ch
annels
if
out_ch
annels
else
channels
if
with
_conv
:
if
use
_conv
:
self
.
Conv2d_0
=
Conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
Conv2d_0
=
Conv2d
(
channels
,
out_ch
annels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
with
_conv
=
with
_conv
self
.
use
_conv
=
use
_conv
self
.
fir_kernel
=
fir_kernel
self
.
fir_kernel
=
fir_kernel
self
.
out_ch
=
out_ch
self
.
out_ch
annels
=
out_ch
annels
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
if
self
.
with
_conv
:
if
self
.
use
_conv
:
h
=
_upsample_conv_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
h
=
_upsample_conv_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
h
=
h
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
else
:
h
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
...
@@ -203,18 +204,19 @@ class FirUpsample(nn.Module):
...
@@ -203,18 +204,19 @@ class FirUpsample(nn.Module):
class
FirDownsample
(
nn
.
Module
):
class
FirDownsample
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
channels
=
None
,
out_ch
annels
=
None
,
use
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
out_ch
annels
=
out_ch
annels
if
out_ch
annels
else
channels
if
with
_conv
:
if
use
_conv
:
self
.
Conv2d_0
=
self
.
Conv2d_0
=
Conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
Conv2d_0
=
self
.
Conv2d_0
=
Conv2d
(
channels
,
out_ch
annels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
fir_kernel
=
fir_kernel
self
.
fir_kernel
=
fir_kernel
self
.
with
_conv
=
with
_conv
self
.
use
_conv
=
use
_conv
self
.
out_ch
=
out_ch
self
.
out_ch
annels
=
out_ch
annels
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
if
self
.
with
_conv
:
if
self
.
use
_conv
:
x
=
_conv_downsample_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
x
=
_conv_downsample_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
x
=
x
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
else
:
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
...
@@ -228,13 +230,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -228,13 +230,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
,
self
,
image_size
=
1024
,
image_size
=
1024
,
num_channels
=
3
,
num_channels
=
3
,
centered
=
False
,
attn_resolutions
=
(
16
,),
attn_resolutions
=
(
16
,),
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
),
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
),
conditional
=
True
,
conditional
=
True
,
conv_size
=
3
,
conv_size
=
3
,
dropout
=
0.0
,
dropout
=
0.0
,
embedding_type
=
"fourier"
,
embedding_type
=
"fourier"
,
fir
=
True
,
# TODO (patil-suraj) remove this option from here and pre-trained model configs
fir
=
True
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
fir_kernel
=
(
1
,
3
,
3
,
1
),
fourier_scale
=
16
,
fourier_scale
=
16
,
init_scale
=
0.0
,
init_scale
=
0.0
,
...
@@ -252,12 +255,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -252,12 +255,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
.
register_to_config
(
self
.
register_to_config
(
image_size
=
image_size
,
image_size
=
image_size
,
num_channels
=
num_channels
,
num_channels
=
num_channels
,
centered
=
centered
,
attn_resolutions
=
attn_resolutions
,
attn_resolutions
=
attn_resolutions
,
ch_mult
=
ch_mult
,
ch_mult
=
ch_mult
,
conditional
=
conditional
,
conditional
=
conditional
,
conv_size
=
conv_size
,
conv_size
=
conv_size
,
dropout
=
dropout
,
dropout
=
dropout
,
embedding_type
=
embedding_type
,
embedding_type
=
embedding_type
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
fir_kernel
=
fir_kernel
,
fourier_scale
=
fourier_scale
,
fourier_scale
=
fourier_scale
,
init_scale
=
init_scale
,
init_scale
=
init_scale
,
...
@@ -307,24 +312,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -307,24 +312,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules
.
append
(
Linear
(
nf
*
4
,
nf
*
4
))
modules
.
append
(
Linear
(
nf
*
4
,
nf
*
4
))
AttnBlock
=
functools
.
partial
(
AttentionBlock
,
overwrite_linear
=
True
,
rescale_output_factor
=
math
.
sqrt
(
2.0
))
AttnBlock
=
functools
.
partial
(
AttentionBlock
,
overwrite_linear
=
True
,
rescale_output_factor
=
math
.
sqrt
(
2.0
))
Up_sample
=
functools
.
partial
(
FirUpsample
,
with_conv
=
resamp_with_conv
,
fir_kernel
=
fir_kernel
)
if
self
.
fir
:
Up_sample
=
functools
.
partial
(
FirUpsample
,
fir_kernel
=
fir_kernel
,
use_conv
=
resamp_with_conv
)
else
:
Up_sample
=
functools
.
partial
(
Upsample
,
name
=
"Conv2d_0"
)
if
progressive
==
"output_skip"
:
if
progressive
==
"output_skip"
:
self
.
pyramid_upsample
=
Up_sample
(
fir_kernel
=
fir_ker
ne
l
,
with
_conv
=
False
)
self
.
pyramid_upsample
=
Up_sample
(
channels
=
No
ne
,
use
_conv
=
False
)
elif
progressive
==
"residual"
:
elif
progressive
==
"residual"
:
pyramid_upsample
=
functools
.
partial
(
Up_sample
,
fir_kernel
=
fir_kernel
,
with
_conv
=
True
)
pyramid_upsample
=
functools
.
partial
(
Up_sample
,
use
_conv
=
True
)
Down_sample
=
functools
.
partial
(
FirDownsample
,
with_conv
=
resamp_with_conv
,
fir_kernel
=
fir_kernel
)
if
self
.
fir
:
Down_sample
=
functools
.
partial
(
FirDownsample
,
fir_kernel
=
fir_kernel
,
use_conv
=
resamp_with_conv
)
else
:
Down_sample
=
functools
.
partial
(
Downsample
,
padding
=
0
,
name
=
"Conv2d_0"
)
if
progressive_input
==
"input_skip"
:
if
progressive_input
==
"input_skip"
:
self
.
pyramid_downsample
=
Down_sample
(
fir_kernel
=
fir_ker
ne
l
,
with
_conv
=
False
)
self
.
pyramid_downsample
=
Down_sample
(
channels
=
No
ne
,
use
_conv
=
False
)
elif
progressive_input
==
"residual"
:
elif
progressive_input
==
"residual"
:
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
fir_kernel
=
fir_kernel
,
with
_conv
=
True
)
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
use
_conv
=
True
)
ResnetBlock
=
functools
.
partial
(
ResnetBlock
=
functools
.
partial
(
ResnetBlockBigGANpp
,
ResnetBlockBigGANpp
,
act
=
act
,
act
=
act
,
dropout
=
dropout
,
dropout
=
dropout
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
fir_kernel
=
fir_kernel
,
init_scale
=
init_scale
,
init_scale
=
init_scale
,
skip_rescale
=
skip_rescale
,
skip_rescale
=
skip_rescale
,
...
@@ -361,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -361,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch
*=
2
in_ch
*=
2
elif
progressive_input
==
"residual"
:
elif
progressive_input
==
"residual"
:
modules
.
append
(
pyramid_downsample
(
in_ch
=
input_pyramid_ch
,
out_ch
=
in_ch
))
modules
.
append
(
pyramid_downsample
(
channels
=
input_pyramid_ch
,
out_ch
annels
=
in_ch
))
input_pyramid_ch
=
in_ch
input_pyramid_ch
=
in_ch
hs_c
.
append
(
in_ch
)
hs_c
.
append
(
in_ch
)
...
@@ -402,7 +415,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -402,7 +415,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
)
)
pyramid_ch
=
channels
pyramid_ch
=
channels
elif
progressive
==
"residual"
:
elif
progressive
==
"residual"
:
modules
.
append
(
pyramid_upsample
(
in_ch
=
pyramid_ch
,
out_ch
=
in_ch
))
modules
.
append
(
pyramid_upsample
(
channels
=
pyramid_ch
,
out_ch
annels
=
in_ch
))
pyramid_ch
=
in_ch
pyramid_ch
=
in_ch
else
:
else
:
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
...
@@ -446,7 +459,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -446,7 +459,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
temb
=
None
temb
=
None
# If input data is in [0, 1]
# If input data is in [0, 1]
x
=
2
*
x
-
1.0
if
not
self
.
config
.
centered
:
x
=
2
*
x
-
1.0
# Downsampling block
# Downsampling block
input_pyramid
=
None
input_pyramid
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment