renzhc / diffusers_dcu / Commits

Commit 8830af11, authored Jun 30, 2022 by patil-suraj

get rid of ResnetBlockDDPMpp and related functions

parent 81e71447

Showing 2 changed files with 17 additions and 139 deletions (+17 -139):

src/diffusers/models/resnet.py                      +0  -68
src/diffusers/models/unet_sde_score_estimation.py  +17  -71
src/diffusers/models/resnet.py

@@ -637,62 +637,6 @@ class ResnetBlockBigGANpp(nn.Module):
             return (x + h) / np.sqrt(2.0)
 
 
-# unet_score_estimation.py
-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
 # unet_rl.py
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
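A note on the context line at the top of this hunk: with skip_rescale enabled, both the surviving block and the removed one return (x + h) / np.sqrt(2.0). If the shortcut x and the residual branch h are roughly independent with unit variance, their sum has variance 2, so dividing by sqrt(2) keeps activations at unit scale however many blocks are stacked. A quick numerical sanity check (illustrative, not from the repo):

import torch

x = torch.randn(100_000)
h = torch.randn(100_000)

print((x + h).std())              # ~1.414, i.e. sqrt(2): variance doubles
print(((x + h) / 2**0.5).std())   # ~1.0: the rescaled sum stays unit-variance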
@@ -957,18 +901,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
 
 
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
 def _setup_kernel(k):
     k = np.asarray(k, dtype=np.float32)
     if k.ndim == 1:
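For context on the deleted NIN ("network-in-network") layer: it mixes channels pointwise by contracting the channel axis of an NCHW tensor with an (in_dim, num_units) matrix, which is exactly what a 1x1 convolution does. A minimal equivalence check (weights and shapes here are illustrative):

import torch
import torch.nn as nn

in_dim, num_units = 8, 16
W = torch.randn(in_dim, num_units)
b = torch.zeros(num_units)
x = torch.randn(2, in_dim, 4, 4)

# NIN-style: NCHW -> NHWC, contract the channel axis, permute back.
y_nin = (torch.tensordot(x.permute(0, 2, 3, 1), W, dims=1) + b).permute(0, 3, 1, 2)

# The same map expressed as a 1x1 convolution.
conv = nn.Conv2d(in_dim, num_units, kernel_size=1)
conv.weight.data = W.t().reshape(num_units, in_dim, 1, 1)
conv.bias.data = b

print(torch.allclose(y_nin, conv(x), atol=1e-5))  # True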
src/diffusers/models/unet_sde_score_estimation.py

@@ -28,7 +28,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
+from .resnet import ResnetBlockBigGANpp
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
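The context line above is the resampling primitive the remaining BigGAN++ blocks rely on: upfirdn2d means "upsample, apply an FIR filter, downsample". A reference sketch of those semantics in plain PyTorch (an illustration only, not the repo's implementation):

import torch
import torch.nn.functional as F

def upfirdn2d_ref(x, kernel, up=1, down=1, pad=(0, 0)):
    # x: (N, C, H, W); kernel: a 2-D FIR filter shared across channels.
    n, c, h, w = x.shape
    # 1) Upsample by zero insertion.
    out = torch.zeros(n, c, h * up, w * up, dtype=x.dtype, device=x.device)
    out[:, :, ::up, ::up] = x
    # 2) Zero-pad (or crop, for negative pads) both spatial dims.
    out = F.pad(out, (pad[0], pad[1], pad[0], pad[1]))
    # 3) Convolve every channel with the same (flipped) kernel.
    weight = kernel.flip(0, 1)[None, None].repeat(c, 1, 1, 1)
    out = F.conv2d(out, weight, groups=c)
    # 4) Downsample by keeping every down-th sample.
    return out[:, :, ::down, ::down]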
@@ -305,32 +305,6 @@ def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1
     return conv
 
 
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
 def get_act(nonlinearity):
     """Get activation functions from the config file."""
@@ -575,30 +549,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
         elif progressive_input == "residual":
             pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
 
-        if resblock_type == "ddpm":
-            ResnetBlock = functools.partial(
-                ResnetBlockDDPMpp,
-                act=act,
-                dropout=dropout,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-        elif resblock_type == "biggan":
-            ResnetBlock = functools.partial(
-                ResnetBlockBigGANpp,
-                act=act,
-                dropout=dropout,
-                fir=fir,
-                fir_kernel=fir_kernel,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-        else:
-            raise ValueError(f"resblock type {resblock_type} unrecognized.")
+        ResnetBlock = functools.partial(
+            ResnetBlockBigGANpp,
+            act=act,
+            dropout=dropout,
+            fir=fir,
+            fir_kernel=fir_kernel,
+            init_scale=init_scale,
+            skip_rescale=skip_rescale,
+            temb_dim=nf * 4,
+        )
 
         # Downsampling block
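The change above collapses both configuration branches into one unconditional functools.partial: the constructor arguments shared by every block are bound once, and each later call site supplies only what varies. A self-contained sketch of the pattern (the Block class here is a stand-in, not the real ResnetBlockBigGANpp):

import functools

class Block:
    def __init__(self, in_ch, up=False, down=False, act=None, dropout=0.0):
        self.in_ch, self.up, self.down = in_ch, up, down
        self.act, self.dropout = act, dropout

# Bind the configuration shared by every block once...
ResnetBlock = functools.partial(Block, act="silu", dropout=0.1)

# ...so call sites pass only per-block arguments, as in the diff.
blk = ResnetBlock(in_ch=64, down=True)
print(blk.in_ch, blk.down, blk.act)  # 64 True silu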
@@ -622,10 +582,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)
 
             if i_level != self.num_resolutions - 1:
-                if resblock_type == "ddpm":
-                    modules.append(Downsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(ResnetBlock(down=True, in_ch=in_ch))
 
                 if progressive_input == "input_skip":
                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -678,10 +635,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{progressive} is not a valid name")
 
             if i_level != 0:
-                if resblock_type == "ddpm":
-                    modules.append(Upsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(ResnetBlock(in_ch=in_ch, up=True))
 
         assert not hs_c
@@ -741,12 +695,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs.append(h)
 
             if i_level != self.num_resolutions - 1:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](hs[-1])
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](hs[-1], temb)
-                    m_idx += 1
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1
 
                 if self.progressive_input == "input_skip":
                     input_pyramid = self.pyramid_downsample(input_pyramid)
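This forward-pass change works because ResnetBlockBigGANpp takes the time embedding as an argument, so every entry in the module list can now be invoked uniformly as modules[m_idx](h, temb); the old plain Downsample/Upsample modules took no temb, which is what forced the resblock_type branch. A minimal sketch of the uniform-call convention (illustrative module, not the repo's):

import torch.nn as nn

class ResampleBlock(nn.Module):
    def __init__(self, in_ch, down=False):
        super().__init__()
        self.pool = nn.AvgPool2d(2) if down else nn.Identity()
        self.conv = nn.Conv2d(in_ch, in_ch, 3, padding=1)

    def forward(self, x, temb=None):
        # Accepting temb in the signature (even when a block ignores it)
        # is what lets the caller drop the per-type branch.
        return self.conv(self.pool(x))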
@@ -818,12 +768,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{self.progressive} is not a valid name")
 
             if i_level != 0:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](h)
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](h, temb)
-                    m_idx += 1
+                h = modules[m_idx](h, temb)
+                m_idx += 1
 
         assert not hs