Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c691bb2f
Unverified
Commit
c691bb2f
authored
Jul 01, 2022
by
Suraj Patil
Committed by
GitHub
Jul 01, 2022
Browse files
Merge pull request #60 from huggingface/add-fir-back
fix unde sde for vp model.
parents
abedfb08
4c293e0e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
33 deletions
+75
-33
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+36
-8
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+39
-25
No files found.
src/diffusers/models/resnet.py
View file @
c691bb2f
from
abc
import
abstractmethod
from
functools
import
partial
import
numpy
as
np
import
torch
...
...
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
dims
=
2
,
out_channels
=
None
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
dims
=
2
,
out_channels
=
None
,
name
=
"conv"
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
self
.
use_conv_transpose
=
use_conv_transpose
self
.
name
=
name
conv
=
None
if
use_conv_transpose
:
self
.
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
elif
use_conv
:
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
if
name
==
"conv"
:
self
.
conv
=
conv
else
:
self
.
Conv2d_0
=
conv
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
...
...
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
if
self
.
name
==
"conv"
:
x
=
self
.
conv
(
x
)
else
:
x
=
self
.
Conv2d_0
(
x
)
return
x
...
...
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
if
name
==
"conv"
:
self
.
conv
=
conv
elif
name
==
"Conv2d_0"
:
self
.
Conv2d_0
=
conv
else
:
self
.
op
=
conv
...
...
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
if
self
.
name
==
"conv"
:
return
self
.
conv
(
x
)
elif
self
.
name
==
"Conv2d_0"
:
return
self
.
Conv2d_0
(
x
)
else
:
return
self
.
op
(
x
)
...
...
@@ -390,6 +405,7 @@ class ResnetBlockBigGANpp(nn.Module):
up
=
False
,
down
=
False
,
dropout
=
0.1
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
skip_rescale
=
True
,
init_scale
=
0.0
,
...
...
@@ -400,8 +416,20 @@ class ResnetBlockBigGANpp(nn.Module):
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
up
=
up
self
.
down
=
down
self
.
fir
=
fir
self
.
fir_kernel
=
fir_kernel
if
self
.
up
:
if
self
.
fir
:
self
.
upsample
=
partial
(
upsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
upsample
=
partial
(
F
.
interpolate
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
elif
self
.
down
:
if
self
.
fir
:
self
.
downsample
=
partial
(
downsample_2d
,
k
=
self
.
fir_kernel
,
factor
=
2
)
else
:
self
.
downsample
=
partial
(
F
.
avg_pool2d
,
kernel_size
=
2
,
stride
=
2
)
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
...
...
@@ -424,11 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
h
=
upsample
_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample
_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
upsample
(
h
)
x
=
self
.
upsample
(
x
)
elif
self
.
down
:
h
=
downsample
_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample
_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
h
=
self
.
downsample
(
h
)
x
=
self
.
downsample
(
x
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
c691bb2f
...
...
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
ResnetBlockBigGANpp
,
downsample_2d
,
upfirdn2d
,
upsample_2d
from
.resnet
import
Downsample
,
ResnetBlockBigGANpp
,
Upsample
,
downsample_2d
,
upfirdn2d
,
upsample_2d
def
_setup_kernel
(
k
):
...
...
@@ -184,18 +184,19 @@ class Combine(nn.Module):
class
FirUpsample
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
channels
=
None
,
out_ch
annels
=
None
,
use
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
if
with
_conv
:
self
.
Conv2d_0
=
Conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
with
_conv
=
with
_conv
out_ch
annels
=
out_ch
annels
if
out_ch
annels
else
channels
if
use
_conv
:
self
.
Conv2d_0
=
Conv2d
(
channels
,
out_ch
annels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
use
_conv
=
use
_conv
self
.
fir_kernel
=
fir_kernel
self
.
out_ch
=
out_ch
self
.
out_ch
annels
=
out_ch
annels
def
forward
(
self
,
x
):
if
self
.
with
_conv
:
if
self
.
use
_conv
:
h
=
_upsample_conv_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
h
=
h
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
h
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
...
...
@@ -203,18 +204,19 @@ class FirUpsample(nn.Module):
class
FirDownsample
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
=
None
,
out_ch
=
None
,
with
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
def
__init__
(
self
,
channels
=
None
,
out_ch
annels
=
None
,
use
_conv
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
)):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
if
with
_conv
:
self
.
Conv2d_0
=
self
.
Conv2d_0
=
Conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
out_ch
annels
=
out_ch
annels
if
out_ch
annels
else
channels
if
use
_conv
:
self
.
Conv2d_0
=
self
.
Conv2d_0
=
Conv2d
(
channels
,
out_ch
annels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
fir_kernel
=
fir_kernel
self
.
with
_conv
=
with
_conv
self
.
out_ch
=
out_ch
self
.
use
_conv
=
use
_conv
self
.
out_ch
annels
=
out_ch
annels
def
forward
(
self
,
x
):
if
self
.
with
_conv
:
if
self
.
use
_conv
:
x
=
_conv_downsample_2d
(
x
,
self
.
Conv2d_0
.
weight
,
k
=
self
.
fir_kernel
)
x
=
x
+
self
.
Conv2d_0
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
...
...
@@ -228,13 +230,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
,
image_size
=
1024
,
num_channels
=
3
,
centered
=
False
,
attn_resolutions
=
(
16
,),
ch_mult
=
(
1
,
2
,
4
,
8
,
16
,
32
,
32
,
32
),
conditional
=
True
,
conv_size
=
3
,
dropout
=
0.0
,
embedding_type
=
"fourier"
,
fir
=
True
,
# TODO (patil-suraj) remove this option from here and pre-trained model configs
fir
=
True
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
fourier_scale
=
16
,
init_scale
=
0.0
,
...
...
@@ -252,12 +255,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
.
register_to_config
(
image_size
=
image_size
,
num_channels
=
num_channels
,
centered
=
centered
,
attn_resolutions
=
attn_resolutions
,
ch_mult
=
ch_mult
,
conditional
=
conditional
,
conv_size
=
conv_size
,
dropout
=
dropout
,
embedding_type
=
embedding_type
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
fourier_scale
=
fourier_scale
,
init_scale
=
init_scale
,
...
...
@@ -307,24 +312,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules
.
append
(
Linear
(
nf
*
4
,
nf
*
4
))
AttnBlock
=
functools
.
partial
(
AttentionBlock
,
overwrite_linear
=
True
,
rescale_output_factor
=
math
.
sqrt
(
2.0
))
Up_sample
=
functools
.
partial
(
FirUpsample
,
with_conv
=
resamp_with_conv
,
fir_kernel
=
fir_kernel
)
if
self
.
fir
:
Up_sample
=
functools
.
partial
(
FirUpsample
,
fir_kernel
=
fir_kernel
,
use_conv
=
resamp_with_conv
)
else
:
Up_sample
=
functools
.
partial
(
Upsample
,
name
=
"Conv2d_0"
)
if
progressive
==
"output_skip"
:
self
.
pyramid_upsample
=
Up_sample
(
fir_kernel
=
fir_ker
ne
l
,
with
_conv
=
False
)
self
.
pyramid_upsample
=
Up_sample
(
channels
=
No
ne
,
use
_conv
=
False
)
elif
progressive
==
"residual"
:
pyramid_upsample
=
functools
.
partial
(
Up_sample
,
fir_kernel
=
fir_kernel
,
with
_conv
=
True
)
pyramid_upsample
=
functools
.
partial
(
Up_sample
,
use
_conv
=
True
)
Down_sample
=
functools
.
partial
(
FirDownsample
,
with_conv
=
resamp_with_conv
,
fir_kernel
=
fir_kernel
)
if
self
.
fir
:
Down_sample
=
functools
.
partial
(
FirDownsample
,
fir_kernel
=
fir_kernel
,
use_conv
=
resamp_with_conv
)
else
:
Down_sample
=
functools
.
partial
(
Downsample
,
padding
=
0
,
name
=
"Conv2d_0"
)
if
progressive_input
==
"input_skip"
:
self
.
pyramid_downsample
=
Down_sample
(
fir_kernel
=
fir_ker
ne
l
,
with
_conv
=
False
)
self
.
pyramid_downsample
=
Down_sample
(
channels
=
No
ne
,
use
_conv
=
False
)
elif
progressive_input
==
"residual"
:
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
fir_kernel
=
fir_kernel
,
with
_conv
=
True
)
pyramid_downsample
=
functools
.
partial
(
Down_sample
,
use
_conv
=
True
)
ResnetBlock
=
functools
.
partial
(
ResnetBlockBigGANpp
,
act
=
act
,
dropout
=
dropout
,
fir
=
fir
,
fir_kernel
=
fir_kernel
,
init_scale
=
init_scale
,
skip_rescale
=
skip_rescale
,
...
...
@@ -361,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch
*=
2
elif
progressive_input
==
"residual"
:
modules
.
append
(
pyramid_downsample
(
in_ch
=
input_pyramid_ch
,
out_ch
=
in_ch
))
modules
.
append
(
pyramid_downsample
(
channels
=
input_pyramid_ch
,
out_ch
annels
=
in_ch
))
input_pyramid_ch
=
in_ch
hs_c
.
append
(
in_ch
)
...
...
@@ -402,7 +415,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
)
pyramid_ch
=
channels
elif
progressive
==
"residual"
:
modules
.
append
(
pyramid_upsample
(
in_ch
=
pyramid_ch
,
out_ch
=
in_ch
))
modules
.
append
(
pyramid_upsample
(
channels
=
pyramid_ch
,
out_ch
annels
=
in_ch
))
pyramid_ch
=
in_ch
else
:
raise
ValueError
(
f
"
{
progressive
}
is not a valid name"
)
...
...
@@ -446,7 +459,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
temb
=
None
# If input data is in [0, 1]
x
=
2
*
x
-
1.0
if
not
self
.
config
.
centered
:
x
=
2
*
x
-
1.0
# Downsampling block
input_pyramid
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment