renzhc / diffusers_dcu · Commits · ebf3717c

Commit ebf3717c authored Jun 29, 2022 by Patrick von Platen

resnet in one file

parent e5d9baf0
Changes: showing 7 changed files with 1123 additions and 564 deletions (+1123 -564)

src/diffusers/models/resnet.py                     +697 -122
src/diffusers/models/unet.py                        +41  -43
src/diffusers/models/unet_glide.py                 +102 -111
src/diffusers/models/unet_grad_tts.py               +21  -19
src/diffusers/models/unet_ldm.py                   +100 -104
src/diffusers/models/unet_rl.py                     +30  -30
src/diffusers/models/unet_sde_score_estimation.py  +132 -135
src/diffusers/models/resnet.py  View file @ ebf3717c
import string
from abc import abstractmethod

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
...
@@ -54,6 +58,18 @@ def nonlinearity(x, swish=1.0):
    return x * F.sigmoid(x * float(swish))
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
...
@@ -134,154 +150,713 @@ class Downsample(nn.Module):
        return self.op(x)
# class UNetUpsample(nn.Module):
# def __init__(self, in_channels, with_conv):
# super().__init__()
# self.with_conv = with_conv
# if self.with_conv:
# self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
#
# def forward(self, x):
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
# if self.with_conv:
# x = self.conv(x)
# return x
#
#
# class GlideUpsample(nn.Module):
# """
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
# applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
# upsampling occurs in the inner-two dimensions. #"""
#
# def __init__(self, channels, use_conv, dims=2, out_channels=None):
# super().__init__()
# self.channels = channels
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.dims = dims
# if use_conv:
# self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
#
# def forward(self, x):
# assert x.shape[1] == self.channels
# if self.dims == 3:
# x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
# else:
# x = F.interpolate(x, scale_factor=2, mode="nearest")
# if self.use_conv:
# x = self.conv(x)
# return x
#
#
# class LDMUpsample(nn.Module):
# """
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
# If 3D, then
# upsampling occurs in the inner-two dimensions. #"""
#
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
# super().__init__()
# self.channels = channels
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.dims = dims
# if use_conv:
# self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
#
# def forward(self, x):
# assert x.shape[1] == self.channels
# if self.dims == 3:
# x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
# else:
# x = F.interpolate(x, scale_factor=2, mode="nearest")
# if self.use_conv:
# x = self.conv(x)
# return x
#
#
# class GradTTSUpsample(torch.nn.Module):
# def __init__(self, dim):
# super(Upsample, self).__init__()
# self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)
#
#
# TODO (patil-suraj): needs test
# class Upsample1d(nn.Module):
# def __init__(self, dim):
# super().__init__()
# self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
#
# def forward(self, x):
# return self.conv(x)

# RESNETS
# unet_glide.py & unet_ldm.py
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
    use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
    on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
    downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels, swish=1.0),
            nn.Identity(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
        elif down:
            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
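A standalone sketch (not part of the commit) of the scale-shift conditioning used in `ResBlock.forward` above: the timestep embedding is projected to 2 * C channels, split into scale and shift, and applied as GroupNorm(h) * (1 + scale) + shift. All shapes and names here are illustrative.

import torch
import torch.nn as nn

B, C, H, W, emb_dim = 2, 64, 8, 8, 128
h = torch.randn(B, C, H, W)
emb = torch.randn(B, emb_dim)

emb_proj = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, 2 * C))   # plays the role of emb_layers
norm = nn.GroupNorm(num_groups=32, num_channels=C)

emb_out = emb_proj(emb)[:, :, None, None]        # broadcast over H and W
scale, shift = torch.chunk(emb_out, 2, dim=1)    # each [B, C, 1, 1]
out = norm(h) * (1 + scale) + shift
print(out.shape)  # torch.Size([2, 64, 8, 8])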
# unet.py
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
# unet_grad_tts.py
class ResnetBlockGradTTS(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlockGradTTS, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = torch.nn.Identity()

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output
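A small shape sketch (not part of the commit) of how the mask in `ResnetBlockGradTTS` broadcasts, assuming features of shape [B, C, H, W] and a length mask of shape [B, 1, 1, W]; `x * mask` zeroes padded frames before each block and again on the residual path.

import torch

B, C, H, W = 2, 64, 80, 172
x = torch.randn(B, C, H, W)
lengths = torch.tensor([172, 100])                                   # hypothetical sequence lengths
mask = (torch.arange(W)[None, :] < lengths[:, None]).float()[:, None, None, :]  # [B, 1, 1, W]

masked = x * mask
print(masked[1, :, :, 100:].abs().max())  # tensor(0.) -- padded positions are zeroed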
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(inp_channels, out_channels, kernel_size),
                Conv1dBlock(out_channels, out_channels, kernel_size),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            RearrangeDim(),
            # Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
        out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)
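A broadcast sketch (not part of the commit) for the `time_mlp` above: a [batch, embed_dim] embedding is projected to [batch, out_channels] and given a trailing axis, which is what `RearrangeDim` / `Rearrange("batch t -> batch t 1")` does, so it adds onto a [batch, out_channels, horizon] signal. Shapes are illustrative.

import torch
import torch.nn as nn

B, embed_dim, out_channels, horizon = 4, 32, 16, 100
t_emb = torch.randn(B, embed_dim)
feats = torch.randn(B, out_channels, horizon)

proj = nn.Linear(embed_dim, out_channels)
bias = proj(t_emb)[:, :, None]      # [B, out_channels, 1], the result RearrangeDim produces for 2-D input
out = feats + bias                  # broadcasts over the horizon axis
print(out.shape)                    # torch.Size([4, 16, 100])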
# unet_score_estimation.py
class ResnetBlockBigGANpp(nn.Module):
    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        up=False,
        down=False,
        dropout=0.1,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    ):
        super().__init__()

        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
            if self.fir:
                h = upsample_2d(h, self.fir_kernel, factor=2)
                x = upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_upsample_2d(h, factor=2)
                x = naive_upsample_2d(x, factor=2)
        elif self.down:
            if self.fir:
                h = downsample_2d(h, self.fir_kernel, factor=2)
                x = downsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_downsample_2d(h, factor=2)
                x = naive_downsample_2d(x, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)
# unet_score_estimation.py
class ResnetBlockDDPMpp(nn.Module):
    """ResBlock adapted from DDPM."""

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        conv_shortcut=False,
        dropout=0.1,
        skip_rescale=False,
        init_scale=0.0,
    ):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))
        h = self.Conv_0(h)
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)
        if x.shape[1] != self.out_ch:
            if self.conv_shortcut:
                x = self.Conv_2(x)
            else:
                x = self.NIN_0(x)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)
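A numerical sketch (not part of the commit) of why the score-estimation blocks offer `skip_rescale`: for approximately independent unit-variance x and h, x + h has variance near 2, while (x + h) / sqrt(2) stays near unit variance.

import torch

torch.manual_seed(0)
x = torch.randn(1_000_000)
h = torch.randn(1_000_000)
print((x + h).var())                    # ~2.0
print(((x + h) / (2.0 ** 0.5)).var())   # ~1.0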
# HELPER Modules


def normalization(channels, swish=0.0):
    """
    Make a standard normalization layer, with an optional swish activation.

    :param channels: number of input channels. :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * F.sigmoid(y * float(self.swish))
        return y


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))
class Block(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class RearrangeDim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
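A shape round-trip sketch (not part of the commit) of what `Conv1dBlock` does around GroupNorm: the 1-D feature map gets a dummy spatial axis before normalization and the axis is dropped afterwards, reproducing the einops pair that the `RearrangeDim` comments refer to. Shapes are illustrative.

import torch
import torch.nn as nn

B, C, L = 2, 16, 50
x = torch.randn(B, C, L)

x4 = x[:, :, None, :]                # [B, C, 1, L], the 3-D branch of RearrangeDim
y4 = nn.GroupNorm(8, C)(x4)          # normalize over channel groups
y = y4[:, :, 0, :]                   # back to [B, C, L], the 4-D branch of RearrangeDim
print(y.shape)                       # torch.Size([2, 16, 50])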
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
    )
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv
def default_init(scale=1.0):
    """The same initialization used in DDPM."""
    scale = 1e-10 if scale == 0 else scale
    return variance_scaling(scale, "fan_avg", "uniform")
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX."""

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
        else:
            raise ValueError("invalid distribution for variance scaling initializer")

    return init
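A quick empirical check (not part of the commit) of the "fan_avg" / "uniform" branch above: samples drawn as (rand * 2 - 1) * sqrt(3 * variance) have variance close to variance = scale / ((fan_in + fan_out) / 2). Fan values here are arbitrary.

import torch

scale = 1.0
fan_in, fan_out = 64, 128
variance = scale / ((fan_in + fan_out) / 2)

samples = (torch.rand(1_000_000) * 2.0 - 1.0) * (3 * variance) ** 0.5
print(samples.var().item(), variance)   # both ~0.0104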
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
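A self-contained sketch (not part of the commit) of what upfirdn2d does for the default upsampling kernel used by upsample_2d below: insert zeros between pixels, then convolve with a 2x2 box filter; with the corresponding padding this reproduces nearest-neighbor upsampling. The example does not call the functions in this file.

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 5.0).view(1, 1, 2, 2)

up = torch.zeros(1, 1, 4, 4)
up[:, :, ::2, ::2] = x                     # zero-insertion upsample by factor 2
w = torch.ones(1, 1, 2, 2)                 # normalized box kernel scaled by gain * factor**2
out = F.conv2d(F.pad(up, (1, 0, 1, 0)), w)

print(torch.equal(out, F.interpolate(x, scale_factor=2, mode="nearest")))  # True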
def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    Args:
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
    multiple of the upsampling factor.
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
          C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor**2))
    p = k.shape[0] - factor
    return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    Args:
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
          C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def naive_upsample_2d(x, factor=2):
    _N, C, H, W = x.shape
    x = torch.reshape(x, (-1, C, H, 1, W, 1))
    x = x.repeat(1, 1, 1, factor, 1, factor)
    return torch.reshape(x, (-1, C, H * factor, W * factor))
def naive_downsample_2d(x, factor=2):
    _N, C, H, W = x.shape
    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
    return torch.mean(x, dim=(3, 5))
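naive_downsample_2d above is plain 2x2 average pooling written with reshape + mean; a quick standalone equivalence check (not part of the commit) against F.avg_pool2d.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
factor = 2

# reshape each 2x2 block onto its own axes, then average over those axes
naive = torch.reshape(x, (-1, 3, 8 // factor, factor, 8 // factor, factor)).mean(dim=(3, 5))
print(torch.allclose(naive, F.avg_pool2d(x, factor)))  # True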
class NIN(nn.Module):
    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        y = contract_inner(x, self.W) + self.b
        return y.permute(0, 3, 1, 2)
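The NIN ("network-in-network") layer above is a per-pixel linear map written as an einsum over the channel axis; a standalone sketch (not part of the commit) showing it matches a 1x1 convolution with the same weights. Names and sizes are illustrative.

import torch
import torch.nn as nn

B, C_in, C_out, H, W = 2, 8, 16, 5, 5
x = torch.randn(B, C_in, H, W)
W_nin = torch.randn(C_in, C_out)
b_nin = torch.randn(C_out)

# einsum form used by NIN.forward: permute to NHWC, contract the channel axis, permute back
y_nin = torch.einsum("bhwc,cd->bhwd", x.permute(0, 2, 3, 1), W_nin) + b_nin
y_nin = y_nin.permute(0, 3, 1, 2)

conv = nn.Conv2d(C_in, C_out, kernel_size=1)
conv.weight.data = W_nin.t().reshape(C_out, C_in, 1, 1)
conv.bias.data = b_nin
print(torch.allclose(y_nin, conv(x), atol=1e-5))  # True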
def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k


def contract_inner(x, y):
    """tensordot(x, y, 1)."""
    x_chars = list(string.ascii_lowercase[: len(x.shape)])
    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
    out_chars = x_chars[:-1] + y_chars[1:]
    return _einsum(x_chars, y_chars, out_chars, x, y)


def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)


# class ResnetBlock(nn.Module):
# def __init__(
# self,
# *,
# in_channels,
# out_channels=None,
# conv_shortcut=False,
# dropout,
# temb_channels=512,
# use_scale_shift_norm=False,
# ):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
# self.use_scale_shift_norm = use_scale_shift_norm
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# temp_out_channles = 2 * out_channels if use_scale_shift_norm else out_channels
# self.temb_proj = torch.nn.Linear(temb_channels, temp_out_channles)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
# # TODO: check if this broadcasting works correctly for 1D and 3D
# temb = self.temb_proj(nonlinearity(temb))[:, :, None, None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(temb, 2, dim=1)
# h = self.norm2(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + temb
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
# return x + h
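A tiny sketch (not part of the commit) of what _setup_kernel produces for the fir_kernel=(1, 3, 3, 1) default of ResnetBlockBigGANpp: a separable 1-D tap list becomes a normalized 2-D kernel that sums to 1.

import numpy as np

k = np.asarray([1, 3, 3, 1], dtype=np.float32)
k = np.outer(k, k)       # separable 2-D kernel
k /= np.sum(k)           # normalize so constant inputs keep their magnitude
print(k.shape, k.sum())  # (4, 4) ~1.0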
src/diffusers/models/unet.py  View file @ ebf3717c
...
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResnetBlock, Upsample


def nonlinearity(x):
...
@@ -34,46 +34,46 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
# class ResnetBlock(nn.Module):
# def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
#
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
#
# h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
#
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
#
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
#
# return x + h
class UNetModel(ModelMixin, ConfigMixin):
...
@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
-            attn_2 = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
...
@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):
            down = nn.Module()
            down.block = block
            down.attn = attn
-            down.attn_2 = attn_2
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                curr_res = curr_res // 2
...
src/diffusers/models/unet_glide.py  View file @ ebf3717c
...
@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResBlock, TimestepBlock, Upsample


def convert_module_to_f16(l):
...
@@ -96,16 +96,14 @@ def zero_module(module):
    return module


# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument. #"""
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings. #"""


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x


# class ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. # # :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
# :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
# use_conv: if True and out_channels is specified, use a spatial
# convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
# on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
# downsampling. #"""
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels, swish=1.0),
# nn.Identity(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# """
# Apply the block to a Tensor, conditioned on a timestep embedding. # # :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
# :return: an [N x C x ...] Tensor of outputs. #"""
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h


class GlideUNetModel(ModelMixin, ConfigMixin):
...
src/diffusers/models/unet_grad_tts.py  View file @ ebf3717c
...
@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample
+from .resnet import ResnetBlockGradTTS as ResnetBlock
+from .resnet import Upsample


class Mish(torch.nn.Module):
...
@@ -34,24 +36,24 @@ class Block(torch.nn.Module):
        return output * mask
# class ResnetBlock(torch.nn.Module):
# def __init__(self, dim, dim_out, time_emb_dim, groups=8):
# super(ResnetBlock, self).__init__()
# self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
#
# self.block1 = Block(dim, dim_out, groups=groups)
# self.block2 = Block(dim_out, dim_out, groups=groups)
# if dim != dim_out:
# self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
# else:
# self.res_conv = torch.nn.Identity()
#
# def forward(self, x, mask, time_emb):
# h = self.block1(x, mask)
# h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
# h = self.block2(h, mask)
# output = h + self.res_conv(x * mask)
# return output


class Residual(torch.nn.Module):
...
src/diffusers/models/unet_ldm.py  View file @ ebf3717c
...
@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResBlock, TimestepBlock, Upsample


def exists(val):
...
@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):
        return x[:, :, 0]


# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument. #"""
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings. #"""


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x
# class A_ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. :param channels: the number of input channels.
# :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param
# out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a
# spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
# :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
# on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
# downsampling. #"""
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels),
# nn.SiLU(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels),
# nn.SiLU(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h


class QKVAttention(nn.Module):
...
src/diffusers/models/unet_rl.py  View file @ ebf3717c
...
@@ -6,6 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
+from .resnet import ResidualTemporalBlock


class SinusoidalPosEmb(nn.Module):
...
@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):
        return self.block(x)
# class ResidualTemporalBlock(nn.Module):
# def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
# super().__init__()
#
# self.blocks = nn.ModuleList(
# [
# Conv1dBlock(inp_channels, out_channels, kernel_size),
# Conv1dBlock(out_channels, out_channels, kernel_size),
# ]
# )
#
# self.time_mlp = nn.Sequential(
# nn.Mish(),
# nn.Linear(embed_dim, out_channels),
# RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
# )
#
# self.residual_conv = (
# nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
# )
#
# def forward(self, x, t):
# """
# x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
# out_channels x horizon ]
# """
# out = self.blocks[0](x) + self.time_mlp(t)
# out = self.blocks[1](out)
# return out + self.residual_conv(x)


class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
...
src/diffusers/models/unet_sde_score_estimation.py  View file @ ebf3717c
...
@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
+from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
...
@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
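downsample_2d filters the input with the (normalized) FIR kernel k and then keeps every factor-th sample; the asymmetric pad ((p + 1) // 2, p // 2) compensates for the kernel length. A rough sketch of the same computation with a depthwise convolution — an approximation for illustration, not the upfirdn2d implementation used above:

import torch
import torch.nn.functional as F

def fir_downsample_2d(x, k=(1, 3, 3, 1), factor=2, gain=1):
    # Build the separable 2-D kernel and normalize it, as the helper above does.
    k = torch.tensor(k, dtype=torch.float32)
    k = torch.outer(k, k)
    k = k / k.sum() * gain
    p = k.shape[0] - factor                  # same padding arithmetic as above
    pad0, pad1 = (p + 1) // 2, p // 2
    c = x.shape[1]
    w = k[None, None].repeat(c, 1, 1, 1)     # one filter per channel (depthwise)
    x = F.pad(x, (pad0, pad1, pad0, pad1))
    return F.conv2d(x, w, stride=factor, groups=c)

x = torch.randn(1, 3, 8, 8)
print(fir_downsample_2d(x).shape)  # torch.Size([1, 3, 4, 4])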
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):    (old name)
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):         (new name in this commit)
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
...
@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
    return conv
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):    (old name)
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):         (new name in this commit)
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
...
@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
    return conv

conv1x1 = ddpm_conv1x1    (alias removed by this commit)
conv3x3 = ddpm_conv3x3    (alias removed by this commit)
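Both helpers build an ordinary nn.Conv2d and then overwrite conv.weight.data with weights drawn from default_init(init_scale). default_init itself is not part of this hunk, so the sketch below uses a simple fan-average uniform initializer as an assumed stand-in; only the wrapping pattern is taken from the code above:

import torch
import torch.nn as nn

def _toy_init(scale=1.0):
    """Stand-in initializer (NOT the repository's default_init): scaled fan-average uniform."""
    def init(shape, dtype=torch.float32):
        fan_in = shape[1] * shape[2] * shape[3]
        fan_out = shape[0] * shape[2] * shape[3]
        limit = (3.0 * scale / ((fan_in + fan_out) / 2.0)) ** 0.5
        return (torch.rand(shape, dtype=dtype) * 2 - 1) * limit
    return init

def toy_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
    """Same wrapping pattern as conv3x3 above: build the conv, then overwrite its weight data."""
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
    )
    conv.weight.data = _toy_init(init_scale)(conv.weight.data.shape)
    if bias:
        nn.init.zeros_(conv.bias)  # assumed; the bias handling is not shown in this hunk
    return conv

layer = toy_conv3x3(16, 32, init_scale=1e-10)  # a tiny scale gives near-zero initial output
print(layer(torch.randn(1, 16, 8, 8)).shape)   # torch.Size([1, 32, 8, 8])

If default_init follows the usual score_sde convention of treating init_scale=0 as a tiny value, the Conv_1 layers created with init_scale=0.0 in the resnet blocks below start out with near-zero weights, so each block initially behaves almost like a skip connection.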
def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)
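_einsum only assembles the einsum subscript string from lists of index characters (it presumably backs the NIN 1x1 contraction layer referenced further down). A quick check of what the string building produces:

import torch

def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)

x = torch.randn(2, 3)
y = torch.randn(3, 4)
out = _einsum(["a", "b"], ["b", "c"], ["a", "c"], x, y)  # builds "ab,bc->ac"
print(torch.allclose(out, x @ y))  # True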
...
@@ -494,135 +491,135 @@ class Downsample(nn.Module):
        return x
(In the new revision this class is commented out here and imported from .resnet instead — see the import added above.)

class ResnetBlockDDPMpp(nn.Module):
    """ResBlock adapted from DDPM."""

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        conv_shortcut=False,
        dropout=0.1,
        skip_rescale=False,
        init_scale=0.0,
    ):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))
        h = self.Conv_0(h)
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)
        if x.shape[1] != self.out_ch:
            if self.conv_shortcut:
                x = self.Conv_2(x)
            else:
                x = self.NIN_0(x)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)
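The skip_rescale option divides the residual sum by sqrt(2): if x and h are roughly independent with similar variance, their sum has about twice the variance, and the rescaling keeps activations at a stable scale as blocks are stacked. A quick numerical check:

import torch

x = torch.randn(100_000)
h = torch.randn(100_000)
print(torch.var(x + h).item())               # ~2.0
print(torch.var((x + h) / 2 ** 0.5).item())  # ~1.0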
(Likewise commented out in the new revision and imported from .resnet.)

class ResnetBlockBigGANpp(nn.Module):
    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        up=False,
        down=False,
        dropout=0.1,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    ):
        super().__init__()

        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
            if self.fir:
                h = upsample_2d(h, self.fir_kernel, factor=2)
                x = upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_upsample_2d(h, factor=2)
                x = naive_upsample_2d(x, factor=2)
        elif self.down:
            if self.fir:
                h = downsample_2d(h, self.fir_kernel, factor=2)
                x = downsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_downsample_2d(h, factor=2)
                x = naive_downsample_2d(x, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)
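The BigGAN-style block resamples inside the block: when up or down is set, both the hidden path h and the shortcut x are resampled before Conv_0 and the 1x1 shortcut, so the final residual addition still lines up. A rough sketch with F.interpolate and avg_pool2d standing in for naive_upsample_2d / naive_downsample_2d (the FIR variants are not reproduced here):

import torch
import torch.nn.functional as F

def naive_upsample(x, factor=2):
    # Stand-in for naive_upsample_2d: nearest-neighbour repetition.
    return F.interpolate(x, scale_factor=factor, mode="nearest")

def naive_downsample(x, factor=2):
    # Stand-in for naive_downsample_2d: averaging over factor x factor windows.
    return F.avg_pool2d(x, factor)

x = torch.randn(1, 8, 16, 16)
h = torch.randn(1, 8, 16, 16)
# Downsampling branch of the forward pass: both paths are resampled together,
# so the residual addition at the end still has matching shapes.
h, x = naive_downsample(h), naive_downsample(x)
print(h.shape, x.shape)  # torch.Size([1, 8, 8, 8]) torch.Size([1, 8, 8, 8])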
class NCSNpp(ModelMixin, ConfigMixin):
...