renzhc / diffusers_dcu

Commit ac796924, authored Jun 24, 2022 by Patrick von Platen
Parent: bd9c9fbf

    add score estimation model

Showing 4 changed files with 1070 additions and 19 deletions (+1070, -19)
Changed files:
  src/diffusers/__init__.py                           +1     -3
  src/diffusers/models/__init__.py                    +1     -0
  src/diffusers/models/unet_rl.py                     +17    -16
  src/diffusers/models/unet_sde_score_estimation.py   +1051  -0
src/diffusers/__init__.py
@@ -7,9 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"
 
 from .modeling_utils import ModelMixin
-from .models.unet import UNetModel
-from .models.unet_ldm import UNetLDMModel
-from .models.unet_rl import TemporalUNet
+from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
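With this change the score-estimation model is re-exported at the package root alongside the existing UNets. A minimal import sketch (not part of the commit; it assumes this revision of the package is installed, and `config` stands for a hypothetical ml_collections-style config object that NCSNpp expects):

# Sketch only: top-level exports introduced by this commit.
from diffusers import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel

# model = NCSNpp(config)  # hypothetical config; NCSNpp reads fields such as config.model.nf and config.data.image_size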
src/diffusers/models/__init__.py
@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide
 from .unet_grad_tts import UNetGradTTSModel
 from .unet_ldm import UNetLDMModel
 from .unet_rl import TemporalUNet
+from .unet_sde_score_estimation import NCSNpp
src/diffusers/models/unet_rl.py
@@ -5,6 +5,7 @@ import math
 import torch
 import torch.nn as nn
 
 try:
     import einops
     from einops.layers.torch import Rearrange
@@ -104,14 +105,14 @@ class ResidualTemporalBlock(nn.Module):
 class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
     def __init__(
         self,
         training_horizon,
         transition_dim,
         cond_dim,
         predict_epsilon=False,
         clip_denoised=True,
         dim=32,
         dim_mults=(1, 2, 4, 8),
     ):
         super().__init__()
@@ -211,14 +212,14 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
 class TemporalValue(nn.Module):
     def __init__(
         self,
         horizon,
         transition_dim,
         cond_dim,
         dim=32,
         time_dim=None,
         out_dim=1,
         dim_mults=(1, 2, 4, 8),
     ):
         super().__init__()
src/diffusers/models/unet_sde_score_estimation.py  (new file, mode 100644)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# helpers functions

import functools
import math
import string

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)

# Function ported from StyleGAN2
def get_weight(module, shape, weight_var="weight", kernel_init=None):
    """Get/create weight tensor for a convolution or fully-connected layer."""

    return module.param(weight_var, kernel_init, shape)


class Conv2d(nn.Module):
    """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""

    def __init__(
        self,
        in_ch,
        out_ch,
        kernel,
        up=False,
        down=False,
        resample_kernel=(1, 3, 3, 1),
        use_bias=True,
        kernel_init=None,
    ):
        super().__init__()
        assert not (up and down)
        assert kernel >= 1 and kernel % 2 == 1
        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
        if kernel_init is not None:
            self.weight.data = kernel_init(self.weight.data.shape)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_ch))

        self.up = up
        self.down = down
        self.resample_kernel = resample_kernel
        self.kernel = kernel
        self.use_bias = use_bias

    def forward(self, x):
        if self.up:
            x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
        elif self.down:
            x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
        else:
            x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)

        if self.use_bias:
            x = x + self.bias.reshape(1, -1, 1, 1)

        return x

def naive_upsample_2d(x, factor=2):
    _N, C, H, W = x.shape
    x = torch.reshape(x, (-1, C, H, 1, W, 1))
    x = x.repeat(1, 1, 1, factor, 1, factor)
    return torch.reshape(x, (-1, C, H * factor, W * factor))


def naive_downsample_2d(x, factor=2):
    _N, C, H, W = x.shape
    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
    return torch.mean(x, dim=(3, 5))

def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:      Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:      Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
        k:      FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain:   Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1

    # Check weight shape.
    assert len(w.shape) == 4
    convH = w.shape[2]
    convW = w.shape[3]
    inC = w.shape[1]

    assert convW == convH

    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor**2))
    p = (k.shape[0] - factor) - (convW - 1)

    stride = (factor, factor)

    # Determine data dimensions.
    stride = [1, 1, factor, factor]
    output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
    output_padding = (
        output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
        output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
    )
    assert output_padding[0] >= 0 and output_padding[1] >= 0
    num_groups = _shape(x, 1) // inC

    # Transpose weights.
    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
    w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
    # Original TF code.
    # x = tf.nn.conv2d_transpose(
    #     x,
    #     w,
    #     output_shape=output_shape,
    #     strides=stride,
    #     padding='VALID',
    #     data_format=data_format)
    # JAX equivalent

    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))

def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:      Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:      Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
        k:      FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain:   Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    _outC, _inC, convH, convW = w.shape
    assert convW == convH
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = (k.shape[0] - factor) + (convW - 1)
    s = [factor, factor]
    x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
    return F.conv2d(x, w, stride=s, padding=0)

def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k


def _shape(x, dim):
    return x.shape[dim]

def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and upsamples each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the upsampling factor.

    Args:
        x:      Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:      FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain:   Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor**2))
    p = k.shape[0] - factor
    return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))


def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and downsamples each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the downsampling factor.

    Args:
        x:      Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:      FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain:   Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))

def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv


def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
    )
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv


conv1x1 = ddpm_conv1x1
conv3x3 = ddpm_conv3x3

def _einsum(a, b, c, x, y):
    einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
    return torch.einsum(einsum_str, x, y)


def contract_inner(x, y):
    """tensordot(x, y, 1)."""
    x_chars = list(string.ascii_lowercase[: len(x.shape)])
    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
    out_chars = x_chars[:-1] + y_chars[1:]
    return _einsum(x_chars, y_chars, out_chars, x, y)

class NIN(nn.Module):
    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        y = contract_inner(x, self.W) + self.b
        return y.permute(0, 3, 1, 2)

def get_act(config):
    """Get activation functions from the config file."""

    if config.model.nonlinearity.lower() == "elu":
        return nn.ELU()
    elif config.model.nonlinearity.lower() == "relu":
        return nn.ReLU()
    elif config.model.nonlinearity.lower() == "lrelu":
        return nn.LeakyReLU(negative_slope=0.2)
    elif config.model.nonlinearity.lower() == "swish":
        return nn.SiLU()
    else:
        raise NotImplementedError("activation function does not exist!")

def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    emb = math.log(max_positions) / (half_dim - 1)
    # emb = math.log(2.) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
    # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode="constant")
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb

def default_init(scale=1.0):
    """The same initialization used in DDPM."""
    scale = 1e-10 if scale == 0 else scale
    return variance_scaling(scale, "fan_avg", "uniform")


def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX."""

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
        else:
            raise ValueError("invalid distribution for variance scaling initializer")

    return init

class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

class Combine(nn.Module):
    """Combine information from skip connections."""

    def __init__(self, dim1, dim2, method="cat"):
        super().__init__()
        self.Conv_0 = conv1x1(dim1, dim2)
        self.method = method

    def forward(self, x, y):
        h = self.Conv_0(x)
        if self.method == "cat":
            return torch.cat([h, y], dim=1)
        elif self.method == "sum":
            return h + y
        else:
            raise ValueError(f"Method {self.method} not recognized.")

class AttnBlockpp(nn.Module):
    """Channel-wise self-attention block. Modified from DDPM."""

    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
        super().__init__()
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.GroupNorm_0(x)
        q = self.NIN_0(h)
        k = self.NIN_1(h)
        v = self.NIN_2(h)

        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
        w = torch.reshape(w, (B, H, W, H * W))
        w = F.softmax(w, dim=-1)
        w = torch.reshape(w, (B, H, W, H, W))
        h = torch.einsum("bhwij,bcij->bchw", w, v)
        h = self.NIN_3(h)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)

class Upsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        if not fir:
            if with_conv:
                self.Conv_0 = conv3x3(in_ch, out_ch)
        else:
            if with_conv:
                self.Conv2d_0 = Conv2d(
                    in_ch,
                    out_ch,
                    kernel=3,
                    up=True,
                    resample_kernel=fir_kernel,
                    use_bias=True,
                    kernel_init=default_init(),
                )
        self.fir = fir
        self.with_conv = with_conv
        self.fir_kernel = fir_kernel
        self.out_ch = out_ch

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.fir:
            h = F.interpolate(x, (H * 2, W * 2), "nearest")
            if self.with_conv:
                h = self.Conv_0(h)
        else:
            if not self.with_conv:
                h = upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = self.Conv2d_0(x)

        return h

class Downsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        if not fir:
            if with_conv:
                self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
        else:
            if with_conv:
                self.Conv2d_0 = Conv2d(
                    in_ch,
                    out_ch,
                    kernel=3,
                    down=True,
                    resample_kernel=fir_kernel,
                    use_bias=True,
                    kernel_init=default_init(),
                )
        self.fir = fir
        self.fir_kernel = fir_kernel
        self.with_conv = with_conv
        self.out_ch = out_ch

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.fir:
            if self.with_conv:
                x = F.pad(x, (0, 1, 0, 1))
                x = self.Conv_0(x)
            else:
                x = F.avg_pool2d(x, 2, stride=2)
        else:
            if not self.with_conv:
                x = downsample_2d(x, self.fir_kernel, factor=2)
            else:
                x = self.Conv2d_0(x)

        return x

class ResnetBlockDDPMpp(nn.Module):
    """ResBlock adapted from DDPM."""

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        conv_shortcut=False,
        dropout=0.1,
        skip_rescale=False,
        init_scale=0.0,
    ):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))
        h = self.Conv_0(h)
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)
        if x.shape[1] != self.out_ch:
            if self.conv_shortcut:
                x = self.Conv_2(x)
            else:
                x = self.NIN_0(x)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)

class ResnetBlockBigGANpp(nn.Module):
    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        up=False,
        down=False,
        dropout=0.1,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    ):
        super().__init__()

        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
            if self.fir:
                h = upsample_2d(h, self.fir_kernel, factor=2)
                x = upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_upsample_2d(h, factor=2)
                x = naive_upsample_2d(x, factor=2)
        elif self.down:
            if self.fir:
                h = downsample_2d(h, self.fir_kernel, factor=2)
                x = downsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_downsample_2d(h, factor=2)
                x = naive_downsample_2d(x, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)

class NCSNpp(nn.Module):
    """NCSN++ model"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.act = act = get_act(config)
        # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))

        self.nf = nf = config.model.nf
        ch_mult = config.model.ch_mult
        self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
        self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        resamp_with_conv = config.model.resamp_with_conv
        self.num_resolutions = num_resolutions = len(ch_mult)
        self.all_resolutions = all_resolutions = [config.data.image_size // (2**i) for i in range(num_resolutions)]

        self.conditional = conditional = config.model.conditional  # noise-conditional
        fir = config.model.fir
        fir_kernel = config.model.fir_kernel
        self.skip_rescale = skip_rescale = config.model.skip_rescale
        self.resblock_type = resblock_type = config.model.resblock_type.lower()
        self.progressive = progressive = config.model.progressive.lower()
        self.progressive_input = progressive_input = config.model.progressive_input.lower()
        self.embedding_type = embedding_type = config.model.embedding_type.lower()
        init_scale = config.model.init_scale
        assert progressive in ["none", "output_skip", "residual"]
        assert progressive_input in ["none", "input_skip", "residual"]
        assert embedding_type in ["fourier", "positional"]
        combine_method = config.model.progressive_combine.lower()
        combiner = functools.partial(Combine, method=combine_method)

        modules = []
        # timestep/noise_level embedding; only for continuous training
        if embedding_type == "fourier":
            # Gaussian Fourier features embeddings.
            assert config.training.continuous, "Fourier features are only used for continuous training."

            modules.append(GaussianFourierProjection(embedding_size=nf, scale=config.model.fourier_scale))
            embed_dim = 2 * nf
        elif embedding_type == "positional":
            embed_dim = nf
        else:
            raise ValueError(f"embedding type {embedding_type} unknown.")

        if conditional:
            modules.append(nn.Linear(embed_dim, nf * 4))
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)
            modules.append(nn.Linear(nf * 4, nf * 4))
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

        AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale)

        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive == "output_skip":
            self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive == "residual":
            pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)

        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive_input == "input_skip":
            self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive_input == "residual":
            pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)

        if resblock_type == "ddpm":
            ResnetBlock = functools.partial(
                ResnetBlockDDPMpp,
                act=act,
                dropout=dropout,
                init_scale=init_scale,
                skip_rescale=skip_rescale,
                temb_dim=nf * 4,
            )
        elif resblock_type == "biggan":
            ResnetBlock = functools.partial(
                ResnetBlockBigGANpp,
                act=act,
                dropout=dropout,
                fir=fir,
                fir_kernel=fir_kernel,
                init_scale=init_scale,
                skip_rescale=skip_rescale,
                temb_dim=nf * 4,
            )
        else:
            raise ValueError(f"resblock type {resblock_type} unrecognized.")

        # Downsampling block

        channels = config.data.num_channels
        if progressive_input != "none":
            input_pyramid_ch = channels

        modules.append(conv3x3(channels, nf))
        hs_c = [nf]

        in_ch = nf
        for i_level in range(num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch

                if all_resolutions[i_level] in attn_resolutions:
                    modules.append(AttnBlock(channels=in_ch))
                hs_c.append(in_ch)

            if i_level != num_resolutions - 1:
                if resblock_type == "ddpm":
                    modules.append(Downsample(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(down=True, in_ch=in_ch))

                if progressive_input == "input_skip":
                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
                    if combine_method == "cat":
                        in_ch *= 2
                elif progressive_input == "residual":
                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
                    input_pyramid_ch = in_ch

                hs_c.append(in_ch)

        in_ch = hs_c[-1]
        modules.append(ResnetBlock(in_ch=in_ch))
        modules.append(AttnBlock(channels=in_ch))
        modules.append(ResnetBlock(in_ch=in_ch))

        pyramid_ch = 0
        # Upsampling block
        for i_level in reversed(range(num_resolutions)):
            for i_block in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                in_ch = out_ch

            if all_resolutions[i_level] in attn_resolutions:
                modules.append(AttnBlock(channels=in_ch))

            if progressive != "none":
                if i_level == num_resolutions - 1:
                    if progressive == "output_skip":
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
                        pyramid_ch = channels
                    elif progressive == "residual":
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, in_ch, bias=True))
                        pyramid_ch = in_ch
                    else:
                        raise ValueError(f"{progressive} is not a valid name.")
                else:
                    if progressive == "output_skip":
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
                        pyramid_ch = channels
                    elif progressive == "residual":
                        modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
                        pyramid_ch = in_ch
                    else:
                        raise ValueError(f"{progressive} is not a valid name")

            if i_level != 0:
                if resblock_type == "ddpm":
                    modules.append(Upsample(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(in_ch=in_ch, up=True))

        assert not hs_c

        if progressive != "output_skip":
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))

        self.all_modules = nn.ModuleList(modules)

    def forward(self, x, time_cond):
        # import ipdb; ipdb.set_trace()
        # timestep/noise_level embedding; only for continuous training
        modules = self.all_modules
        m_idx = 0
        if self.embedding_type == "fourier":
            # Gaussian Fourier features embeddings.
            used_sigmas = time_cond
            temb = modules[m_idx](torch.log(used_sigmas))
            m_idx += 1
        elif self.embedding_type == "positional":
            # Sinusoidal positional embeddings.
            timesteps = time_cond
            used_sigmas = self.sigmas[time_cond.long()]
            temb = get_timestep_embedding(timesteps, self.nf)
        else:
            raise ValueError(f"embedding type {self.embedding_type} unknown.")

        if self.conditional:
            temb = modules[m_idx](temb)
            m_idx += 1
            temb = modules[m_idx](self.act(temb))
            m_idx += 1
        else:
            temb = None

        if not self.config.data.centered:
            # If input data is in [0, 1]
            x = 2 * x - 1.0

        # Downsampling block
        input_pyramid = None
        if self.progressive_input != "none":
            input_pyramid = x

        hs = [modules[m_idx](x)]
        m_idx += 1
        for i_level in range(self.num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(self.num_res_blocks):
                h = modules[m_idx](hs[-1], temb)
                m_idx += 1
                if h.shape[-1] in self.attn_resolutions:
                    h = modules[m_idx](h)
                    m_idx += 1

                hs.append(h)

            if i_level != self.num_resolutions - 1:
                if self.resblock_type == "ddpm":
                    h = modules[m_idx](hs[-1])
                    m_idx += 1
                else:
                    h = modules[m_idx](hs[-1], temb)
                    m_idx += 1

                if self.progressive_input == "input_skip":
                    input_pyramid = self.pyramid_downsample(input_pyramid)
                    h = modules[m_idx](input_pyramid, h)
                    m_idx += 1
                elif self.progressive_input == "residual":
                    input_pyramid = modules[m_idx](input_pyramid)
                    m_idx += 1
                    if self.skip_rescale:
                        input_pyramid = (input_pyramid + h) / np.sqrt(2.0)
                    else:
                        input_pyramid = input_pyramid + h
                    h = input_pyramid

                hs.append(h)

        h = hs[-1]
        h = modules[m_idx](h, temb)
        m_idx += 1
        h = modules[m_idx](h)
        m_idx += 1
        h = modules[m_idx](h, temb)
        m_idx += 1

        pyramid = None

        # Upsampling block
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1

            if h.shape[-1] in self.attn_resolutions:
                h = modules[m_idx](h)
                m_idx += 1

            if self.progressive != "none":
                if i_level == self.num_resolutions - 1:
                    if self.progressive == "output_skip":
                        pyramid = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                    elif self.progressive == "residual":
                        pyramid = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                    else:
                        raise ValueError(f"{self.progressive} is not a valid name.")
                else:
                    if self.progressive == "output_skip":
                        pyramid = self.pyramid_upsample(pyramid)
                        pyramid_h = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid_h = modules[m_idx](pyramid_h)
                        m_idx += 1
                        pyramid = pyramid + pyramid_h
                    elif self.progressive == "residual":
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                        if self.skip_rescale:
                            pyramid = (pyramid + h) / np.sqrt(2.0)
                        else:
                            pyramid = pyramid + h
                        h = pyramid
                    else:
                        raise ValueError(f"{self.progressive} is not a valid name")

            if i_level != 0:
                if self.resblock_type == "ddpm":
                    h = modules[m_idx](h)
                    m_idx += 1
                else:
                    h = modules[m_idx](h, temb)
                    m_idx += 1

        assert not hs

        if self.progressive == "output_skip":
            h = pyramid
        else:
            h = self.act(modules[m_idx](h))
            m_idx += 1
            h = modules[m_idx](h)
            m_idx += 1

        assert m_idx == len(modules)
        if self.config.model.scale_by_sigma:
            used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
            h = h / used_sigmas

        return h
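The resampling and embedding helpers in this new module are self-contained, so they can be exercised without building the full NCSNpp model. A quick shape-check sketch (not part of the commit; it assumes this revision is on the Python path so the module can be imported):

# Sketch only: verify the FIR up/downsampling helpers and the timestep embedding.
import torch

from diffusers.models.unet_sde_score_estimation import (
    downsample_2d,
    get_timestep_embedding,
    naive_upsample_2d,
    upsample_2d,
)

x = torch.randn(1, 3, 8, 8)
print(upsample_2d(x, factor=2).shape)        # torch.Size([1, 3, 16, 16]); default k = [1] * factor (nearest-neighbor)
print(naive_upsample_2d(x, factor=2).shape)  # torch.Size([1, 3, 16, 16]); plain pixel repetition
print(downsample_2d(x, factor=2).shape)      # torch.Size([1, 3, 4, 4]); default k corresponds to average pooling

t = torch.arange(4)
print(get_timestep_embedding(t, embedding_dim=32).shape)  # torch.Size([4, 32]); sinusoidal embedding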