renzhc / diffusers_dcu / Commits / 8b0bc596

Commit 8b0bc596 (unverified), authored Jun 30, 2022 by Suraj Patil, committed by GitHub on Jun 30, 2022

Merge pull request #52 from huggingface/clean-unet-sde

Clean UNetNCSNpp

Parents: 7e0fd19f, f35387b3
Showing 2 changed files with 103 additions and 549 deletions (+103 / -549):

    src/diffusers/models/resnet.py                      +17 / -148
    src/diffusers/models/unet_sde_score_estimation.py   +86 / -401
src/diffusers/models/resnet.py (view file @ 8b0bc596)
@@ -579,7 +579,6 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
-        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -590,20 +589,20 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
-        self.fir = fir
         self.fir_kernel = fir_kernel

-        self.Conv_0 = conv3x3(in_ch, out_ch)
+        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
             nn.init.zeros_(self.Dense_0.bias)

         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
         self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
+            # 1x1 convolution with DDPM initialization.
+            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)

         self.skip_rescale = skip_rescale
         self.act = act
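Annotation (not part of the commit): the hunk above folds the former conv3x3() and conv1x1() call sites into one conv2d() helper, and the only difference between the two call sites is the kernel_size/padding pair. A minimal sketch with plain nn.Conv2d layers shows that both settings preserve the spatial size, which is why a single helper suffices:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 16, 16)
conv_3x3 = nn.Conv2d(8, 16, kernel_size=3, padding=1)  # former conv3x3(in_ch, out_ch)
conv_1x1 = nn.Conv2d(8, 16, kernel_size=1, padding=0)  # former conv1x1(in_ch, out_ch)
print(conv_3x3(x).shape)  # torch.Size([1, 16, 16, 16]) - spatial size preserved
print(conv_1x1(x).shape)  # torch.Size([1, 16, 16, 16])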
@@ -614,19 +613,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))

         if self.up:
-            if self.fir:
-                h = upsample_2d(h, self.fir_kernel, factor=2)
-                x = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_upsample_2d(h, factor=2)
-                x = naive_upsample_2d(x, factor=2)
+            h = upsample_2d(h, self.fir_kernel, factor=2)
+            x = upsample_2d(x, self.fir_kernel, factor=2)
         elif self.down:
-            if self.fir:
-                h = downsample_2d(h, self.fir_kernel, factor=2)
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_downsample_2d(h, factor=2)
-                x = naive_downsample_2d(x, factor=2)
+            h = downsample_2d(h, self.fir_kernel, factor=2)
+            x = downsample_2d(x, self.fir_kernel, factor=2)

         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
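Annotation (not part of the commit): the "add bias per feature map" step mentioned in the comment above works by broadcasting the [N, C] time-embedding projection over the spatial dimensions, as in the [:, :, None, None] indexing used elsewhere in this file. A small sketch with made-up sizes:

import torch

h = torch.randn(2, 16, 8, 8)    # feature maps [N, C, H, W]
temb_proj = torch.randn(2, 16)  # e.g. Dense_0(act(temb)), shape [N, C]
h = h + temb_proj[:, :, None, None]  # one scalar bias per channel, broadcast over H and W
print(h.shape)  # torch.Size([2, 16, 8, 8])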
@@ -645,62 +636,6 @@ class ResnetBlockBigGANpp(nn.Module):
         return (x + h) / np.sqrt(2.0)


-# unet_score_estimation.py
-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
 # unet_rl.py
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
@@ -818,32 +753,17 @@ class RearrangeDim(nn.Module):
             raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
-    )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
+    nn.init.zeros_(conv.bias)
+    return conv
+
+
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale

     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
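Annotation (not part of the commit): a self-contained sketch of the simplified initializer and helper defined in the hunk above. The body of _compute_fans below follows the standard JAX-style fan computation; that detail is an assumption, since only its first line appears as context in the diff.

import numpy as np
import torch

def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    # Only the "fan_avg" / "uniform" branch of the old initializer survives, which is the
    # combination the removed default_init() always selected.
    scale = 1e-10 if scale == 0 else scale

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        # Standard JAX-style fans (assumption: these two lines sit in unchanged context).
        return shape[in_axis] * receptive_field_size, shape[out_axis] * receptive_field_size

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

    return init

w = variance_scaling()((16, 8, 3, 3))  # e.g. a conv2d(8, 16) weight
print(w.shape, float(w.abs().max()) <= float(np.sqrt(3 / 108.0)))  # torch.Size([16, 8, 3, 3]) True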
@@ -853,21 +773,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
-
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

     return init
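Annotation (not part of the commit): a worked example of the retained fan_avg / uniform branch, using illustrative sizes (a 2-D weight such as Dense_0 with temb_dim = 256 and out_ch = 64; both numbers are made up here) and the usual fan computation from the unchanged context:

import numpy as np

shape, scale = (64, 256), 1.0
receptive_field_size = np.prod(shape) / shape[1] / shape[0]  # 1.0 for a 2-D weight
fan_in = shape[1] * receptive_field_size                     # 256.0
fan_out = shape[0] * receptive_field_size                    # 64.0
variance = scale / ((fan_in + fan_out) / 2)                  # 1 / 160
print(round(float(np.sqrt(3 * variance)), 4))                # 0.1369: weights ~ U(-0.137, 0.137)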
@@ -965,31 +873,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))


-def naive_upsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H, 1, W, 1))
-    x = x.repeat(1, 1, 1, factor, 1, factor)
-    return torch.reshape(x, (-1, C, H * factor, W * factor))
-
-
-def naive_downsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
-    return torch.mean(x, dim=(3, 5))
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
 def _setup_kernel(k):
     k = np.asarray(k, dtype=np.float32)
     if k.ndim == 1:
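Annotation (not part of the commit): a minimal sketch of what _setup_kernel does with the default FIR filter (1, 3, 3, 1) used throughout this PR. A separable 1-D kernel becomes a normalized 2-D filter via an outer product, so constant inputs keep their scale after filtering:

import numpy as np

k = np.asarray((1, 3, 3, 1), dtype=np.float32)
if k.ndim == 1:
    k = np.outer(k, k)
k /= np.sum(k)
print(k.shape, float(k.sum()))  # (4, 4) 1.0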
@@ -998,17 +881,3 @@ def _setup_kernel(k):
     assert k.ndim == 2
     assert k.shape[0] == k.shape[1]
     return k
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
src/diffusers/models/unet_sde_score_estimation.py (view file @ 8b0bc596)
@@ -17,7 +17,6 @@
 import functools
 import math
-import string

 import numpy as np
 import torch
@@ -28,116 +27,21 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
-
-
-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
-
-
-def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
-    _, channel, in_h, in_w = input.shape
-    input = input.reshape(-1, in_h, in_w, 1)
-
-    _, in_h, in_w, minor = input.shape
-    kernel_h, kernel_w = kernel.shape
-
-    out = input.view(-1, in_h, 1, in_w, 1, minor)
-    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
-    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
-    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
-    out = out[
-        :,
-        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
-        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
-        :,
-    ]
-
-    out = out.permute(0, 3, 1, 2)
-    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
-    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
-    out = F.conv2d(out, w)
-    out = out.reshape(
-        -1,
-        minor,
-        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
-        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
-    )
-    out = out.permute(0, 2, 3, 1)
-    out = out[:, ::down_y, ::down_x, :]
-
-    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-
-    return out.view(-1, channel, out_h, out_w)
-
-
-# Function ported from StyleGAN2
-def get_weight(module, shape, weight_var="weight", kernel_init=None):
-    """Get/create weight tensor for a convolution or fully-connected layer."""
-
-    return module.param(weight_var, kernel_init, shape)
-
-
-class Conv2d(nn.Module):
-    """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
-
-    def __init__(
-        self,
-        in_ch,
-        out_ch,
-        kernel,
-        up=False,
-        down=False,
-        resample_kernel=(1, 3, 3, 1),
-        use_bias=True,
-        kernel_init=None,
-    ):
-        super().__init__()
-        assert not (up and down)
-        assert kernel >= 1 and kernel % 2 == 1
-        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
-        if kernel_init is not None:
-            self.weight.data = kernel_init(self.weight.data.shape)
-        if use_bias:
-            self.bias = nn.Parameter(torch.zeros(out_ch))
-
-        self.up = up
-        self.down = down
-        self.resample_kernel = resample_kernel
-        self.kernel = kernel
-        self.use_bias = use_bias
-
-    def forward(self, x):
-        if self.up:
-            x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
-        elif self.down:
-            x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
-        else:
-            x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
-
-        if self.use_bias:
-            x = x + self.bias.reshape(1, -1, 1, 1)
-
-        return x
-
-
-def naive_upsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H, 1, W, 1))
-    x = x.repeat(1, 1, 1, factor, 1, factor)
-    return torch.reshape(x, (-1, C, H * factor, W * factor))
-
-
-def naive_downsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
-    return torch.mean(x, dim=(3, 5))
-
-
-def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
+from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
+
+
+def _setup_kernel(k):
+    k = np.asarray(k, dtype=np.float32)
+    if k.ndim == 1:
+        k = np.outer(k, k)
+    k /= np.sum(k)
+    assert k.ndim == 2
+    assert k.shape[0] == k.shape[1]
+    return k
+
+
+def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `upsample_2d()` followed by `Conv2d()`.

     Args:
     Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
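Annotation (not part of the commit): the removed upfirdn2d_native above computes its output size as (in * up + pad0 + pad1 - kernel) // down + 1, and upsample_2d pads with ((p + 1) // 2 + factor - 1, p // 2) where p = k.shape[0] - factor. A small sketch that pulls this arithmetic out as a standalone helper, to make the shape behavior easy to check:

def upfirdn2d_out_size(in_h, in_w, kernel_h, kernel_w, up=1, down=1, pad=(0, 0)):
    # Same formula as the removed upfirdn2d_native, applied to both spatial axes.
    out_h = (in_h * up + pad[0] + pad[1] - kernel_h) // down + 1
    out_w = (in_w * up + pad[0] + pad[1] - kernel_w) // down + 1
    return out_h, out_w

# upsample_2d with factor=2 and the default 4-tap kernel: p = 4 - 2 = 2,
# so pad = ((p + 1) // 2 + factor - 1, p // 2) = (2, 1) and a 16x16 input becomes 32x32.
print(upfirdn2d_out_size(16, 16, 4, 4, up=2, pad=(2, 1)))  # (32, 32)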
@@ -176,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
         # Determine data dimensions.
         stride = [1, 1, factor, factor]
-        output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
+        output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
         output_padding = (
-            output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
-            output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
+            output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+            output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
         )
         assert output_padding[0] >= 0 and output_padding[1] >= 0
-        num_groups = _shape(x, 1) // inC
+        num_groups = x.shape[1] // inC

         # Transpose weights.
         w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
@@ -190,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
         w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

         x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
-        # Original TF code.
-        # x = tf.nn.conv2d_transpose(
-        #     x,
-        #     w,
-        #     output_shape=output_shape,
-        #     strides=stride,
-        #     padding='VALID',
-        #     data_format=data_format)
-        # JAX equivalent

         return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))


-def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
+def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `Conv2d()` followed by `downsample_2d()`.

     Args:
     Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -235,138 +130,9 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
     return F.conv2d(x, w, stride=s, padding=0)


-def _setup_kernel(k):
-    k = np.asarray(k, dtype=np.float32)
-    if k.ndim == 1:
-        k = np.outer(k, k)
-    k /= np.sum(k)
-    assert k.ndim == 2
-    assert k.shape[0] == k.shape[1]
-    return k
-
-
-def _shape(x, dim):
-    return x.shape[dim]
-
-
-def upsample_2d(x, k=None, factor=2, gain=1):
-    r"""Upsample a batch of 2D images with the given filter.
-
-    Args:
-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
-    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
-    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
-    a: multiple of the upsampling factor.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
-          corresponds to nearest-neighbor upsampling.
-        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H * factor, W * factor]`
-    """
-    assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * (gain * (factor**2))
-    p = k.shape[0] - factor
-    return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
-
-
-def downsample_2d(x, k=None, factor=2, gain=1):
-    r"""Downsample a batch of 2D images with the given filter.
-
-    Args:
-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
-    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
-    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
-    shape is a multiple of the downsampling factor.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
-          corresponds to average pooling.
-        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H // factor, W // factor]`
-    """
-    assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * gain
-    p = k.shape[0] - factor
-    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
-
-
-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
-    )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
-def get_act(nonlinearity):
-    """Get activation functions from the config file."""
-
-    if nonlinearity.lower() == "elu":
-        return nn.ELU()
-    elif nonlinearity.lower() == "relu":
-        return nn.ReLU()
-    elif nonlinearity.lower() == "lrelu":
-        return nn.LeakyReLU(negative_slope=0.2)
-    elif nonlinearity.lower() == "swish":
-        return nn.SiLU()
-    else:
-        raise NotImplementedError("activation function does not exist!")
-
-
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale

     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -376,31 +142,35 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
-
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

     return init


+def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
+    nn.init.zeros_(conv.bias)
+    return conv
+
+
+def Linear(dim_in, dim_out):
+    linear = nn.Linear(dim_in, dim_out)
+    linear.weight.data = _variance_scaling()(linear.weight.shape)
+    nn.init.zeros_(linear.bias)
+    return linear
+
+
 class Combine(nn.Module):
     """Combine information from skip connections."""

     def __init__(self, dim1, dim2, method="cat"):
         super().__init__()
-        self.Conv_0 = conv1x1(dim1, dim2)
+        # 1x1 convolution with DDPM initialization.
+        self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
         self.method = method

     def forward(self, x, y):
@@ -413,80 +183,40 @@ class Combine(nn.Module):
             raise ValueError(f"Method {self.method} not recognized.")


-class Upsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+class FirUpsample(nn.Module):
+    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
-        if not fir:
-            if with_conv:
-                self.Conv_0 = conv3x3(in_ch, out_ch)
-        else:
-            if with_conv:
-                self.Conv2d_0 = Conv2d(
-                    in_ch,
-                    out_ch,
-                    kernel=3,
-                    up=True,
-                    resample_kernel=fir_kernel,
-                    use_bias=True,
-                    kernel_init=default_init(),
-                )
-        self.fir = fir
+        if with_conv:
+            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
         self.with_conv = with_conv
         self.fir_kernel = fir_kernel
         self.out_ch = out_ch

     def forward(self, x):
-        B, C, H, W = x.shape
-        if not self.fir:
-            h = F.interpolate(x, (H * 2, W * 2), "nearest")
-            if self.with_conv:
-                h = self.Conv_0(h)
-        else:
-            if not self.with_conv:
-                h = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = self.Conv2d_0(x)
+        if self.with_conv:
+            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+        else:
+            h = upsample_2d(x, self.fir_kernel, factor=2)

         return h


-class Downsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+class FirDownsample(nn.Module):
+    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
-        if not fir:
-            if with_conv:
-                self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
-        else:
-            if with_conv:
-                self.Conv2d_0 = Conv2d(
-                    in_ch,
-                    out_ch,
-                    kernel=3,
-                    down=True,
-                    resample_kernel=fir_kernel,
-                    use_bias=True,
-                    kernel_init=default_init(),
-                )
-        self.fir = fir
+        if with_conv:
+            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
         self.with_conv = with_conv
         self.out_ch = out_ch

     def forward(self, x):
-        B, C, H, W = x.shape
-        if not self.fir:
-            if self.with_conv:
-                x = F.pad(x, (0, 1, 0, 1))
-                x = self.Conv_0(x)
-            else:
-                x = F.avg_pool2d(x, 2, stride=2)
-        else:
-            if not self.with_conv:
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                x = self.Conv2d_0(x)
+        if self.with_conv:
+            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+        else:
+            x = downsample_2d(x, self.fir_kernel, factor=2)

         return x
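Annotation (not part of the commit): the non-FIR branches dropped above used ordinary nearest-neighbor upsampling and average-pool downsampling. A minimal self-contained sketch of those removed paths, for comparison with the FIR-filtered versions the new classes keep:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 16, 16)
up = F.interpolate(x, scale_factor=2, mode="nearest")  # old Upsample path when fir=False
down = F.avg_pool2d(x, 2, stride=2)                    # old Downsample path when fir=False
print(up.shape, down.shape)  # torch.Size([1, 4, 32, 32]) torch.Size([1, 4, 8, 8])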
@@ -496,63 +226,52 @@ class NCSNpp(ModelMixin, ConfigMixin):
     def __init__(
         self,
-        centered=False,
         image_size=1024,
         num_channels=3,
-        attention_type="ddpm",
         attn_resolutions=(16,),
         ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
         conditional=True,
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
         fir=True,
+        # TODO (patil-suraj) remove this option from here and pre-trained model configs
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
         nf=16,
-        nonlinearity="swish",
-        normalization="GroupNorm",
         num_res_blocks=1,
         progressive="output_skip",
         progressive_combine="sum",
         progressive_input="input_skip",
         resamp_with_conv=True,
-        resblock_type="biggan",
         scale_by_sigma=True,
         skip_rescale=True,
         continuous=True,
     ):
         super().__init__()
         self.register_to_config(
-            centered=centered,
             image_size=image_size,
             num_channels=num_channels,
-            attention_type=attention_type,
             attn_resolutions=attn_resolutions,
             ch_mult=ch_mult,
             conditional=conditional,
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
             fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
             nf=nf,
-            nonlinearity=nonlinearity,
-            normalization=normalization,
             num_res_blocks=num_res_blocks,
             progressive=progressive,
             progressive_combine=progressive_combine,
             progressive_input=progressive_input,
             resamp_with_conv=resamp_with_conv,
-            resblock_type=resblock_type,
             scale_by_sigma=scale_by_sigma,
             skip_rescale=skip_rescale,
             continuous=continuous,
         )
-        self.act = act = get_act(nonlinearity)
+        self.act = act = nn.SiLU()
         self.nf = nf
         self.num_res_blocks = num_res_blocks
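Annotation (not part of the commit): the removed get_act() returned nn.SiLU() for the default nonlinearity="swish", so hard-coding self.act = nn.SiLU() above is behavior-preserving for the default configuration. A one-line check of the identity SiLU(x) = x * sigmoid(x):

import torch
import torch.nn as nn

x = torch.randn(3)
print(torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x)))  # True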
@@ -562,7 +281,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self.conditional = conditional
         self.skip_rescale = skip_rescale
-        self.resblock_type = resblock_type
         self.progressive = progressive
         self.progressive_input = progressive_input
         self.embedding_type = embedding_type
@@ -585,53 +303,33 @@ class NCSNpp(ModelMixin, ConfigMixin):
         else:
             raise ValueError(f"embedding type {embedding_type} unknown.")

         if conditional:
-            modules.append(Linear(embed_dim, nf * 4))
-            modules.append(Linear(nf * 4, nf * 4))
+            modules.append(nn.Linear(embed_dim, nf * 4))
+            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
+            nn.init.zeros_(modules[-1].bias)
+            modules.append(nn.Linear(nf * 4, nf * 4))
+            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
+            nn.init.zeros_(modules[-1].bias)

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

-        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)

-        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)

-        if resblock_type == "ddpm":
-            ResnetBlock = functools.partial(
-                ResnetBlockDDPMpp,
-                act=act,
-                dropout=dropout,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-        elif resblock_type == "biggan":
-            ResnetBlock = functools.partial(
-                ResnetBlockBigGANpp,
-                act=act,
-                dropout=dropout,
-                fir=fir,
-                fir_kernel=fir_kernel,
-                init_scale=init_scale,
-                skip_rescale=skip_rescale,
-                temb_dim=nf * 4,
-            )
-        else:
-            raise ValueError(f"resblock type {resblock_type} unrecognized.")
+        ResnetBlock = functools.partial(
+            ResnetBlockBigGANpp,
+            act=act,
+            dropout=dropout,
+            fir_kernel=fir_kernel,
+            init_scale=init_scale,
+            skip_rescale=skip_rescale,
+            temb_dim=nf * 4,
+        )

         # Downsampling block
@@ -639,7 +337,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         if progressive_input != "none":
             input_pyramid_ch = channels

-        modules.append(conv3x3(channels, nf))
+        modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
         hs_c = [nf]

         in_ch = nf
@@ -655,10 +353,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)

             if i_level != self.num_resolutions - 1:
-                if resblock_type == "ddpm":
-                    modules.append(Downsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(ResnetBlock(down=True, in_ch=in_ch))

                 if progressive_input == "input_skip":
                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -691,18 +386,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
             if i_level == self.num_resolutions - 1:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                    modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, in_ch, bias=True))
+                    modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name.")
             else:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                    modules.append(
+                        Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
+                    )
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
@@ -711,16 +408,13 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{progressive} is not a valid name")

             if i_level != 0:
-                if resblock_type == "ddpm":
-                    modules.append(Upsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(ResnetBlock(in_ch=in_ch, up=True))

         assert not hs_c

         if progressive != "output_skip":
             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+            modules.append(Conv2d(in_ch, channels, init_scale=init_scale))

         self.all_modules = nn.ModuleList(modules)
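Annotation (not part of the commit): NCSNpp keeps every sub-module in one flat nn.ModuleList built in construction order, and forward() consumes it with a running index m_idx (see the hunks below). A tiny self-contained sketch of that pattern with stand-in layers:

import torch
import torch.nn as nn

modules = [nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)]  # appended in construction order
all_modules = nn.ModuleList(modules)

h, m_idx = torch.randn(1, 4), 0
for _ in range(len(all_modules)):
    h = all_modules[m_idx](h)  # consumed in exactly the same order
    m_idx += 1
print(h.shape)  # torch.Size([1, 2])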
@@ -751,9 +445,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
         else:
             temb = None

-        if not self.config.centered:
-            # If input data is in [0, 1]
-            x = 2 * x - 1.0
+        # If input data is in [0, 1]
+        x = 2 * x - 1.0

         # Downsampling block
         input_pyramid = None
@@ -774,12 +467,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs.append(h)

             if i_level != self.num_resolutions - 1:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](hs[-1])
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](hs[-1], temb)
-                    m_idx += 1
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1

                 if self.progressive_input == "input_skip":
                     input_pyramid = self.pyramid_downsample(input_pyramid)
@@ -851,12 +540,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     raise ValueError(f"{self.progressive} is not a valid name")

             if i_level != 0:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](h)
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](h, temb)
-                    m_idx += 1
+                h = modules[m_idx](h, temb)
+                m_idx += 1

         assert not hs