OpenDAS / diffusers · Commit 1899457b

Unverified commit 1899457b, authored Jun 29, 2022 by Patrick von Platen, committed via GitHub on Jun 29, 2022.

Merge pull request #40 from huggingface/start_resnet_unificiation

    resnet in one file

Parents: e5d9baf0, ebf3717c
Showing 7 changed files with 1123 additions and 564 deletions (+1123, -564).
    src/diffusers/models/resnet.py                       +697  -122
    src/diffusers/models/unet.py                          +41   -43
    src/diffusers/models/unet_glide.py                   +102  -111
    src/diffusers/models/unet_grad_tts.py                 +21   -19
    src/diffusers/models/unet_ldm.py                     +100  -104
    src/diffusers/models/unet_rl.py                       +30   -30
    src/diffusers/models/unet_sde_score_estimation.py    +132  -135
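Taken together, the per-file hunks below replace each model file's local residual block with an import from the unified module. A minimal sketch of the resulting import surface, assuming the diffusers source tree at this commit is on the Python path (the symbol names are the ones visible in the import hunks below; nothing else about the package layout is assumed):

    # Sketch only: mirrors the per-file import hunks in this commit.
    # Inside the package the imports are relative, e.g.
    #     from .resnet import Downsample, ResnetBlock, Upsample
    from diffusers.models.resnet import (
        Downsample,
        Upsample,
        TimestepBlock,          # base class used by unet_glide.py / unet_ldm.py
        ResBlock,               # unet_glide.py & unet_ldm.py
        ResnetBlock,            # unet.py
        ResnetBlockGradTTS,     # unet_grad_tts.py (aliased back to ResnetBlock there)
        ResidualTemporalBlock,  # unet_rl.py
        ResnetBlockBigGANpp,    # unet_sde_score_estimation.py
        ResnetBlockDDPMpp,      # unet_sde_score_estimation.py
    )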
src/diffusers/models/resnet.py
The import block gains `string`, `abstractmethod`, and `numpy`; the torch imports are unchanged context:

    import string
    from abc import abstractmethod

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
@@ -54,6 +58,18 @@ def nonlinearity(x, swish=1.0):

After the unchanged `return x * F.sigmoid(x * float(swish))`, a `TimestepBlock` base class is added ahead of the existing `Upsample` class ("An upsampling layer with an optional convolution."):

    class TimestepBlock(nn.Module):
        """
        Any module where forward() takes timestep embeddings as a second argument.
        """

        @abstractmethod
        def forward(self, x, emb):
            """
            Apply the module to `x` given `emb` timestep embeddings.
            """
@@ -134,154 +150,713 @@ class Downsample(nn.Module):

The hunk starts after the unchanged `return self.op(x)` at the end of `Downsample`.
The per-model upsampling variants that previously lived here as live code (`UNetUpsample`, `GlideUpsample`, `LDMUpsample`, `GradTTSUpsample`, and the `Upsample1d` block marked "# TODO (patil-suraj): needs test") are converted to commented-out reference code; their bodies are preserved verbatim inside comments. In their place the file gains a section that collects the residual blocks from the individual UNet files:

    # RESNETS
    # unet_glide.py & unet_ldm.py
    class ResBlock(TimestepBlock):
        """
        A residual block that can optionally change the number of channels.

        :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
        :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
        use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1
        convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or
        3D. :param use_checkpoint: if True, use gradient checkpointing on this module. :param up: if True, use this
        block for upsampling. :param down: if True, use this block for downsampling.
        """

        def __init__(
            self,
            channels,
            emb_channels,
            dropout,
            out_channels=None,
            use_conv=False,
            use_scale_shift_norm=False,
            dims=2,
            use_checkpoint=False,
            up=False,
            down=False,
        ):
            super().__init__()
            self.channels = channels
            self.emb_channels = emb_channels
            self.dropout = dropout
            self.out_channels = out_channels or channels
            self.use_conv = use_conv
            self.use_checkpoint = use_checkpoint
            self.use_scale_shift_norm = use_scale_shift_norm

            self.in_layers = nn.Sequential(
                normalization(channels, swish=1.0),
                nn.Identity(),
                conv_nd(dims, channels, self.out_channels, 3, padding=1),
            )

            self.updown = up or down

            if up:
                self.h_upd = Upsample(channels, use_conv=False, dims=dims)
                self.x_upd = Upsample(channels, use_conv=False, dims=dims)
            elif down:
                self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
                self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
            else:
                self.h_upd = self.x_upd = nn.Identity()

            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
                ),
            )
            self.out_layers = nn.Sequential(
                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                nn.Dropout(p=dropout),
                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
            )

            if self.out_channels == channels:
                self.skip_connection = nn.Identity()
            elif use_conv:
                self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
            else:
                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

        def forward(self, x, emb):
            """
            Apply the block to a Tensor, conditioned on a timestep embedding.

            :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep
            embeddings. :return: an [N x C x ...] Tensor of outputs.
            """
            if self.updown:
                in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
                h = in_rest(x)
                h = self.h_upd(h)
                x = self.x_upd(x)
                h = in_conv(h)
            else:
                h = self.in_layers(x)
            emb_out = self.emb_layers(emb).type(h.dtype)
            while len(emb_out.shape) < len(h.shape):
                emb_out = emb_out[..., None]
            if self.use_scale_shift_norm:
                out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
                scale, shift = torch.chunk(emb_out, 2, dim=1)
                h = out_norm(h) * (1 + scale) + shift
                h = out_rest(h)
            else:
                h = h + emb_out
                h = self.out_layers(h)
            return self.skip_connection(x) + h
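The `use_scale_shift_norm` branch above is the AdaGN-style conditioning: the embedding is projected to twice the channel count and split into a scale and a shift that modulate the normalized features. A self-contained sketch of just that step (toy shapes, plain PyTorch, not the diffusers API):

    import torch
    import torch.nn as nn

    N, C, H, W = 2, 64, 8, 8
    h = torch.randn(N, C, H, W)              # features after in_layers
    emb_out = torch.randn(N, 2 * C, 1, 1)    # time embedding projected to 2*C and broadcast

    out_norm = nn.GroupNorm(32, C)
    scale, shift = torch.chunk(emb_out, 2, dim=1)   # per-channel scale and shift
    h = out_norm(h) * (1 + scale) + shift           # same arithmetic as ResBlock.forward
    print(h.shape)  # torch.Size([2, 64, 8, 8])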
    # unet.py
    class ResnetBlock(nn.Module):
        def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
            super().__init__()
            self.in_channels = in_channels
            out_channels = in_channels if out_channels is None else out_channels
            self.out_channels = out_channels
            self.use_conv_shortcut = conv_shortcut

            self.norm1 = Normalize(in_channels)
            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
            self.norm2 = Normalize(out_channels)
            self.dropout = torch.nn.Dropout(dropout)
            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
            if self.in_channels != self.out_channels:
                if self.use_conv_shortcut:
                    self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
                else:
                    self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

        def forward(self, x, temb):
            h = x
            h = self.norm1(h)
            h = nonlinearity(h)
            h = self.conv1(h)

            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

            h = self.norm2(h)
            h = nonlinearity(h)
            h = self.dropout(h)
            h = self.conv2(h)

            if self.in_channels != self.out_channels:
                if self.use_conv_shortcut:
                    x = self.conv_shortcut(x)
                else:
                    x = self.nin_shortcut(x)

            return x + h
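A short usage sketch for this block, assuming the file above is importable as `diffusers.models.resnet` and that `Normalize`/`nonlinearity` are defined earlier in the same file (only their call sites are visible in this hunk); the shapes are illustrative:

    import torch
    # Assumed import path; mirrors the file location shown in this diff.
    from diffusers.models.resnet import ResnetBlock

    block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.1, temb_channels=512)
    x = torch.randn(4, 64, 32, 32)   # [N, C, H, W]
    temb = torch.randn(4, 512)       # timestep embedding, [N, temb_channels]
    out = block(x, temb)             # -> [4, 128, 32, 32]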
    # unet_grad_tts.py
    class ResnetBlockGradTTS(torch.nn.Module):
        def __init__(self, dim, dim_out, time_emb_dim, groups=8):
            super(ResnetBlockGradTTS, self).__init__()
            self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))

            self.block1 = Block(dim, dim_out, groups=groups)
            self.block2 = Block(dim_out, dim_out, groups=groups)
            if dim != dim_out:
                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
            else:
                self.res_conv = torch.nn.Identity()

        def forward(self, x, mask, time_emb):
            h = self.block1(x, mask)
            h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
            h = self.block2(h, mask)
            output = h + self.res_conv(x * mask)
            return output
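`ResnetBlockGradTTS` threads a mask through both sub-blocks (in Grad-TTS the inputs are padded mel-spectrogram batches, hence the mask). A hedged usage sketch under the same import-path assumption; shapes and the mask layout are illustrative, with mask values of 1 for valid frames and 0 for padding:

    import torch
    from diffusers.models.resnet import ResnetBlockGradTTS  # path as in this commit

    block = ResnetBlockGradTTS(dim=64, dim_out=64, time_emb_dim=256)
    x = torch.randn(2, 64, 80, 172)     # [batch, channels, feature bins, frames]
    mask = torch.ones(2, 1, 1, 172)     # broadcastable frame mask
    t_emb = torch.randn(2, 256)
    out = block(x, mask, t_emb)         # masked residual output, same shape as x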
    # unet_rl.py
    class ResidualTemporalBlock(nn.Module):
        def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
            super().__init__()

            self.blocks = nn.ModuleList(
                [
                    Conv1dBlock(inp_channels, out_channels, kernel_size),
                    Conv1dBlock(out_channels, out_channels, kernel_size),
                ]
            )

            self.time_mlp = nn.Sequential(
                nn.Mish(),
                nn.Linear(embed_dim, out_channels),
                RearrangeDim(),
                # Rearrange("batch t -> batch t 1"),
            )

            self.residual_conv = (
                nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
            )

        def forward(self, x, t):
            """
            x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
            out_channels x horizon ]
            """
            out = self.blocks[0](x) + self.time_mlp(t)
            out = self.blocks[1](out)
            return out + self.residual_conv(x)
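Per its docstring, this block maps `[batch, inp_channels, horizon]` plus a `[batch, embed_dim]` embedding to `[batch, out_channels, horizon]`. A sketch under the same import-path assumption (note that, as the code above shows, `horizon` is accepted by the constructor but not used when building the layers):

    import torch
    from diffusers.models.resnet import ResidualTemporalBlock  # path as in this commit

    block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=128, horizon=100)
    x = torch.randn(8, 14, 100)   # [batch, inp_channels, horizon]
    t = torch.randn(8, 128)       # [batch, embed_dim]
    out = block(x, t)             # -> [8, 32, 100]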
    # unet_score_estimation.py
    class ResnetBlockBigGANpp(nn.Module):
        def __init__(
            self,
            act,
            in_ch,
            out_ch=None,
            temb_dim=None,
            up=False,
            down=False,
            dropout=0.1,
            fir=False,
            fir_kernel=(1, 3, 3, 1),
            skip_rescale=True,
            init_scale=0.0,
        ):
            super().__init__()

            out_ch = out_ch if out_ch else in_ch
            self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
            self.up = up
            self.down = down
            self.fir = fir
            self.fir_kernel = fir_kernel

            self.Conv_0 = conv3x3(in_ch, out_ch)
            if temb_dim is not None:
                self.Dense_0 = nn.Linear(temb_dim, out_ch)
                self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
                nn.init.zeros_(self.Dense_0.bias)

            self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
            self.Dropout_0 = nn.Dropout(dropout)
            self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
            if in_ch != out_ch or up or down:
                self.Conv_2 = conv1x1(in_ch, out_ch)

            self.skip_rescale = skip_rescale
            self.act = act
            self.in_ch = in_ch
            self.out_ch = out_ch

        def forward(self, x, temb=None):
            h = self.act(self.GroupNorm_0(x))

            if self.up:
                if self.fir:
                    h = upsample_2d(h, self.fir_kernel, factor=2)
                    x = upsample_2d(x, self.fir_kernel, factor=2)
                else:
                    h = naive_upsample_2d(h, factor=2)
                    x = naive_upsample_2d(x, factor=2)
            elif self.down:
                if self.fir:
                    h = downsample_2d(h, self.fir_kernel, factor=2)
                    x = downsample_2d(x, self.fir_kernel, factor=2)
                else:
                    h = naive_downsample_2d(h, factor=2)
                    x = naive_downsample_2d(x, factor=2)

            h = self.Conv_0(h)
            # Add bias to each feature map conditioned on the time embedding
            if temb is not None:
                h += self.Dense_0(self.act(temb))[:, :, None, None]
            h = self.act(self.GroupNorm_1(h))
            h = self.Dropout_0(h)
            h = self.Conv_1(h)

            if self.in_ch != self.out_ch or self.up or self.down:
                x = self.Conv_2(x)

            if not self.skip_rescale:
                return x + h
            else:
                return (x + h) / np.sqrt(2.0)
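With `skip_rescale=True` (the default here) the residual sum is divided by sqrt(2), which keeps the output variance roughly equal to the input variance when `x` and `h` are approximately independent and similarly scaled. A tiny self-contained check of that variance argument:

    import torch

    x, h = torch.randn(100_000), torch.randn(100_000)
    print(x.var().item(), h.var().item())        # each ~1.0
    print((x + h).var().item())                  # ~2.0: variances add for independent terms
    print(((x + h) / (2 ** 0.5)).var().item())   # ~1.0 again, which is what skip_rescale preserves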
    # unet_score_estimation.py
    class ResnetBlockDDPMpp(nn.Module):
        """ResBlock adapted from DDPM."""

        def __init__(
            self,
            act,
            in_ch,
            out_ch=None,
            temb_dim=None,
            conv_shortcut=False,
            dropout=0.1,
            skip_rescale=False,
            init_scale=0.0,
        ):
            super().__init__()
            out_ch = out_ch if out_ch else in_ch
            self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
            self.Conv_0 = conv3x3(in_ch, out_ch)
            if temb_dim is not None:
                self.Dense_0 = nn.Linear(temb_dim, out_ch)
                self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
                nn.init.zeros_(self.Dense_0.bias)
            self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
            self.Dropout_0 = nn.Dropout(dropout)
            self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
            if in_ch != out_ch:
                if conv_shortcut:
                    self.Conv_2 = conv3x3(in_ch, out_ch)
                else:
                    self.NIN_0 = NIN(in_ch, out_ch)

            self.skip_rescale = skip_rescale
            self.act = act
            self.out_ch = out_ch
            self.conv_shortcut = conv_shortcut

        def forward(self, x, temb=None):
            h = self.act(self.GroupNorm_0(x))
            h = self.Conv_0(h)
            if temb is not None:
                h += self.Dense_0(self.act(temb))[:, :, None, None]
            h = self.act(self.GroupNorm_1(h))
            h = self.Dropout_0(h)
            h = self.Conv_1(h)
            if x.shape[1] != self.out_ch:
                if self.conv_shortcut:
                    x = self.Conv_2(x)
                else:
                    x = self.NIN_0(x)
            if not self.skip_rescale:
                return x + h
            else:
                return (x + h) / np.sqrt(2.0)
    # HELPER Modules


    def normalization(channels, swish=0.0):
        """
        Make a standard normalization layer, with an optional swish activation.

        :param channels: number of input channels. :return: an nn.Module for normalization.
        """
        return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


    class GroupNorm32(nn.GroupNorm):
        def __init__(self, num_groups, num_channels, swish, eps=1e-5):
            super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
            self.swish = swish

        def forward(self, x):
            y = super().forward(x.float()).to(x.dtype)
            if self.swish == 1.0:
                y = F.silu(y)
            elif self.swish:
                y = y * F.sigmoid(y * float(self.swish))
            return y


    def linear(*args, **kwargs):
        """
        Create a linear module.
        """
        return nn.Linear(*args, **kwargs)


    def zero_module(module):
        """
        Zero out the parameters of a module and return it.
        """
        for p in module.parameters():
            p.detach().zero_()
        return module
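`normalization(channels, swish=1.0)` therefore behaves like GroupNorm followed by SiLU, while `swish=0.0` is plain GroupNorm; `ResBlock` uses this to fold the activation into the norm when `use_scale_shift_norm` is off. A self-contained equivalence sketch in plain PyTorch (not using the helper itself):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(2, 64, 8, 8)
    gn = nn.GroupNorm(num_groups=32, num_channels=64, eps=1e-5)

    y_swish1 = F.silu(gn(x))   # what GroupNorm32(..., swish=1.0) computes
    y_swish0 = gn(x)           # what GroupNorm32(..., swish=0.0) computes
    print(torch.allclose(y_swish1, gn(x) * torch.sigmoid(gn(x))))  # True: SiLU == x * sigmoid(x)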
    class Mish(torch.nn.Module):
        def forward(self, x):
            return x * torch.tanh(torch.nn.functional.softplus(x))


    class Block(torch.nn.Module):
        def __init__(self, dim, dim_out, groups=8):
            super(Block, self).__init__()
            self.block = torch.nn.Sequential(
                torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
            )

        def forward(self, x, mask):
            output = self.block(x * mask)
            return output * mask
    class Conv1dBlock(nn.Module):
        """
        Conv1d --> GroupNorm --> Mish
        """

        def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
            super().__init__()

            self.block = nn.Sequential(
                nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
                RearrangeDim(),
                # Rearrange("batch channels horizon -> batch channels 1 horizon"),
                nn.GroupNorm(n_groups, out_channels),
                RearrangeDim(),
                # Rearrange("batch channels 1 horizon -> batch channels horizon"),
                nn.Mish(),
            )

        def forward(self, x):
            return self.block(x)
    class RearrangeDim(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, tensor):
            if len(tensor.shape) == 2:
                return tensor[:, :, None]
            if len(tensor.shape) == 3:
                return tensor[:, :, None, :]
            elif len(tensor.shape) == 4:
                return tensor[:, :, 0, :]
            else:
                raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
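`RearrangeDim` stands in for the einops `Rearrange` calls that the original unet_rl code used (the commented-out patterns are kept next to it). A quick self-contained check of the three shape cases, duplicating the same logic:

    import torch

    def rearrange_dim(tensor):
        # Same logic as RearrangeDim.forward above.
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")

    print(rearrange_dim(torch.zeros(4, 8)).shape)         # (4, 8, 1):    "batch t -> batch t 1"
    print(rearrange_dim(torch.zeros(4, 8, 16)).shape)     # (4, 8, 1, 16): "batch channels horizon -> batch channels 1 horizon"
    print(rearrange_dim(torch.zeros(4, 8, 1, 16)).shape)  # (4, 8, 16):    "batch channels 1 horizon -> batch channels horizon"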
    def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
        """1x1 convolution with DDPM initialization."""
        conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
        conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
        nn.init.zeros_(conv.bias)
        return conv


    def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
        """3x3 convolution with DDPM initialization."""
        conv = nn.Conv2d(
            in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
        )
        conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
        nn.init.zeros_(conv.bias)
        return conv


    def default_init(scale=1.0):
        """The same initialization used in DDPM."""
        scale = 1e-10 if scale == 0 else scale
        return variance_scaling(scale, "fan_avg", "uniform")
    def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
        """Ported from JAX."""

        def _compute_fans(shape, in_axis=1, out_axis=0):
            receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
            fan_in = shape[in_axis] * receptive_field_size
            fan_out = shape[out_axis] * receptive_field_size
            return fan_in, fan_out

        def init(shape, dtype=dtype, device=device):
            fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
            if mode == "fan_in":
                denominator = fan_in
            elif mode == "fan_out":
                denominator = fan_out
            elif mode == "fan_avg":
                denominator = (fan_in + fan_out) / 2
            else:
                raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
            variance = scale / denominator
            if distribution == "normal":
                return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
            elif distribution == "uniform":
                return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
            else:
                raise ValueError("invalid distribution for variance scaling initializer")

        return init
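`default_init()` is thus a uniform, fan-average ("fan_avg") variance-scaling initializer, matching the DDPM reference code. A self-contained sketch of what the returned `init(shape)` callable computes for a conv weight shape, using the same arithmetic:

    import numpy as np
    import torch

    def fan_avg_uniform(shape, scale=1.0, in_axis=1, out_axis=0):
        # Same arithmetic as variance_scaling(scale, "fan_avg", "uniform") above.
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        variance = scale / ((fan_in + fan_out) / 2)
        return (torch.rand(*shape) * 2.0 - 1.0) * np.sqrt(3 * variance)

    w = fan_avg_uniform((128, 64, 3, 3))  # e.g. a conv3x3 weight: [out_ch, in_ch, kH, kW]
    print(w.shape, float(w.std()))        # std close to sqrt(2 / (fan_in + fan_out))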
    def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
        return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
    def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
        _, channel, in_h, in_w = input.shape
        input = input.reshape(-1, in_h, in_w, 1)

        _, in_h, in_w, minor = input.shape
        kernel_h, kernel_w = kernel.shape

        out = input.view(-1, in_h, 1, in_w, 1, minor)
        out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
        out = out.view(-1, in_h * up_y, in_w * up_x, minor)

        out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
        out = out[
            :,
            max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
            max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
            :,
        ]

        out = out.permute(0, 3, 1, 2)
        out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
        w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
        out = F.conv2d(out, w)
        out = out.reshape(
            -1,
            minor,
            in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
            in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
        )
        out = out.permute(0, 2, 3, 1)
        out = out[:, ::down_y, ::down_x, :]

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

        return out.view(-1, channel, out_h, out_w)
    def upsample_2d(x, k=None, factor=2, gain=1):
        r"""Upsample a batch of 2D images with the given filter.

        Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the
        given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
        specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
        its shape is a multiple of the upsampling factor.

        Args:
            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
               corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            Tensor of the shape `[N, C, H * factor, W * factor]`
        """
        assert isinstance(factor, int) and factor >= 1
        if k is None:
            k = [1] * factor
        k = _setup_kernel(k) * (gain * (factor**2))
        p = k.shape[0] - factor
        return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
    def downsample_2d(x, k=None, factor=2, gain=1):
        r"""Downsample a batch of 2D images with the given filter.

        Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
        given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
        specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
        its shape is a multiple of the downsampling factor.

        Args:
            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
               corresponds to average pooling.
            factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            Tensor of the shape `[N, C, H // factor, W // factor]`
        """
        assert isinstance(factor, int) and factor >= 1
        if k is None:
            k = [1] * factor
        k = _setup_kernel(k) * gain
        p = k.shape[0] - factor
        return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
    def naive_upsample_2d(x, factor=2):
        _N, C, H, W = x.shape
        x = torch.reshape(x, (-1, C, H, 1, W, 1))
        x = x.repeat(1, 1, 1, factor, 1, factor)
        return torch.reshape(x, (-1, C, H * factor, W * factor))


    def naive_downsample_2d(x, factor=2):
        _N, C, H, W = x.shape
        x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
        return torch.mean(x, dim=(3, 5))
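`naive_upsample_2d` is nearest-neighbor repetition and `naive_downsample_2d` is factor-by-factor average pooling, so downsampling an upsampled tensor returns the original. A self-contained round-trip check using the same reshape/repeat/mean operations:

    import torch

    def naive_up(x, factor=2):
        _N, C, H, W = x.shape
        x = torch.reshape(x, (-1, C, H, 1, W, 1)).repeat(1, 1, 1, factor, 1, factor)
        return torch.reshape(x, (-1, C, H * factor, W * factor))

    def naive_down(x, factor=2):
        _N, C, H, W = x.shape
        x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
        return torch.mean(x, dim=(3, 5))

    x = torch.randn(1, 3, 4, 4)
    print(torch.allclose(naive_down(naive_up(x)), x, atol=1e-6))  # True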
    class NIN(nn.Module):
        def __init__(self, in_dim, num_units, init_scale=0.1):
            super().__init__()
            self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
            self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

        def forward(self, x):
            x = x.permute(0, 2, 3, 1)
            y = contract_inner(x, self.W) + self.b
            return y.permute(0, 3, 1, 2)
    def _setup_kernel(k):
        k = np.asarray(k, dtype=np.float32)
        if k.ndim == 1:
            k = np.outer(k, k)
        k /= np.sum(k)
        assert k.ndim == 2
        assert k.shape[0] == k.shape[1]
        return k


    def contract_inner(x, y):
        """tensordot(x, y, 1)."""
        x_chars = list(string.ascii_lowercase[: len(x.shape)])
        y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
        y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
        out_chars = x_chars[:-1] + y_chars[1:]
        return _einsum(x_chars, y_chars, out_chars, x, y)


    def _einsum(a, b, c, x, y):
        einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
        return torch.einsum(einsum_str, x, y)

The commented-out draft `ResnetBlock` (an experimental variant carrying `use_scale_shift_norm` and a widened `temb_proj`) that previously sat alongside these helpers is dropped in the same hunk.
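`contract_inner(x, y)` contracts the last axis of `x` with the first axis of `y`, i.e. `tensordot(x, y, 1)`; `NIN` uses it as a per-pixel linear layer. A self-contained check of that equivalence:

    import torch

    x = torch.randn(2, 8, 8, 64)   # NHWC features, as in NIN.forward after the permute
    W = torch.randn(64, 32)        # [in_dim, num_units]

    out_einsum = torch.einsum("abcd,de->abce", x, W)   # what contract_inner builds from letter strings
    out_tensordot = torch.tensordot(x, W, dims=1)
    print(torch.allclose(out_einsum, out_tensordot, atol=1e-5))  # True, shape [2, 8, 8, 32]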
src/diffusers/models/unet.py
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin

Only the resnet import changes; the surrounding imports are context:

    from ..modeling_utils import ModelMixin
    from .attention import AttentionBlock
    from .embeddings import get_timestep_embedding
    from .resnet import Downsample, ResnetBlock, Upsample   # was: from .resnet import Downsample, Upsample

`def nonlinearity(x):` follows unchanged.
@@ -34,46 +34,46 @@ def Normalize(in_channels):

After the unchanged `return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)`, the local `ResnetBlock` definition (norm1/conv1, temb_proj, norm2/dropout/conv2, and the conv or 1x1 "nin" shortcut, matching the `ResnetBlock` now defined in resnet.py) is commented out line by line in favor of the imported class. `class UNetModel(ModelMixin, ConfigMixin):` follows unchanged.
@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):

In the downsampling loop (`for i_level in range(self.num_resolutions):` with `block = nn.ModuleList()`, `attn = nn.ModuleList()`, `block_in = ch * in_ch_mult[i_level]`, `block_out = ch * ch_mult[i_level]`, and the inner `for i_block in range(self.num_res_blocks):`), the unused `attn_2 = nn.ModuleList()` line is deleted, and further down `down.attn_2 = attn_2` is deleted as well; `down.block = block`, `down.attn = attn`, and the `Downsample(block_in, use_conv=resamp_with_conv, padding=0)` / `curr_res = curr_res // 2` tail are unchanged. A construction sketch follows below.
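The loop described in these hunks builds one `ModuleList` of residual blocks (and one of attention blocks) per resolution level; only the unused `attn_2` list is dropped. A hedged, self-contained sketch of that construction pattern (the elided inner body and the attention blocks are simplified to placeholders; the variable names mirror the hunk context):

    import torch.nn as nn

    ch, ch_mult, num_res_blocks = 64, (1, 2, 4), 2
    in_ch_mult = (1,) + tuple(ch_mult)
    num_resolutions = len(ch_mult)

    down_blocks = nn.ModuleList()
    for i_level in range(num_resolutions):
        block = nn.ModuleList()
        attn = nn.ModuleList()        # attn_2 is no longer created
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]
        for i_block in range(num_res_blocks):
            # Placeholder for the ResnetBlock(...) / AttentionBlock(...) calls elided in the hunk.
            block.append(nn.Identity())
            block_in = block_out
        down = nn.Module()
        down.block = block
        down.attn = attn              # down.attn_2 is no longer assigned
        down_blocks.append(down)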
src/diffusers/models/unet_glide.py
@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin

    from ..modeling_utils import ModelMixin
    from .attention import AttentionBlock
    from .embeddings import get_timestep_embedding
    from .resnet import Downsample, ResBlock, TimestepBlock, Upsample   # was: from .resnet import Downsample, Upsample

`def convert_module_to_f16(l):` follows unchanged.
@@ -96,16 +96,14 @@ def zero_module(module):

After the unchanged `return module`, the local `TimestepBlock` base class ("Any module where forward() takes timestep embeddings as a second argument.", with its abstract `forward(self, x, emb)`) is commented out; `class TimestepEmbedSequential(nn.Sequential, TimestepBlock):` follows unchanged.
@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

After the unchanged `return x`, the local `ResBlock(TimestepBlock)` definition is commented out line by line: its docstring, the `__init__(self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_scale_shift_norm=False, dims=2, use_checkpoint=False, up=False, down=False)` with in_layers/emb_layers/out_layers/skip_connection, and `forward(self, x, emb)` with the scale-shift branch; the live class now comes from resnet.py. `class GlideUNetModel(ModelMixin, ConfigMixin):` follows unchanged.
src/diffusers/models/unet_grad_tts.py
@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin

    from ..modeling_utils import ModelMixin
    from .attention import LinearAttention
    from .embeddings import get_timestep_embedding
    from .resnet import Downsample
    from .resnet import ResnetBlockGradTTS as ResnetBlock
    from .resnet import Upsample

(previously a single `from .resnet import Downsample, Upsample`); `class Mish(torch.nn.Module):` follows unchanged.
@@ -34,24 +36,24 @@ class Block(torch.nn.Module):

After the unchanged `return output * mask`, the local `ResnetBlock(torch.nn.Module)` (mlp, block1/block2, res_conv, and `forward(self, x, mask, time_emb)`) is commented out line by line; it is replaced by `ResnetBlockGradTTS`, imported above under the old name. `class Residual(torch.nn.Module):` follows unchanged.
src/diffusers/models/unet_ldm.py
@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin

    from ..modeling_utils import ModelMixin
    from .attention import AttentionBlock
    from .embeddings import get_timestep_embedding
    from .resnet import Downsample, ResBlock, TimestepBlock, Upsample   # was: from .resnet import Downsample, Upsample

`def exists(val):` follows unchanged.
@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):

After the unchanged `return x[:, :, 0]`, the local `TimestepBlock` base class is commented out, as in unet_glide.py; `class TimestepEmbedSequential(nn.Sequential, TimestepBlock):` follows unchanged.
@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

After the unchanged `return x`, the local `ResBlock(TimestepBlock)` definition is commented out line by line (the commented copy is renamed `# class A_ResBlock(TimestepBlock):`), covering its docstring, the `__init__` with in_layers/emb_layers/out_layers/skip_connection, and `forward(self, x, emb)` with the scale-shift branch; the live class now comes from resnet.py. `class QKVAttention(nn.Module):` follows unchanged.
src/diffusers/models/unet_rl.py
@@ -6,6 +6,7 @@ import torch.nn as nn

One import line is added:

    from ..configuration_utils import ConfigMixin
    from ..modeling_utils import ModelMixin
    from .embeddings import get_timestep_embedding
    from .resnet import ResidualTemporalBlock

`class SinusoidalPosEmb(nn.Module):` follows unchanged.
@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):

After the unchanged `return self.block(x)`, the local `ResidualTemporalBlock` (blocks, time_mlp with `RearrangeDim()`, residual_conv, and `forward(self, x, t)`) is commented out line by line in favor of the imported class. `class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):` follows unchanged.
src/diffusers/models/unet_sde_score_estimation.py
...
@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    ...
@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
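The pad pair `((p + 1) // 2, p // 2)` splits the FIR padding as evenly as possible, putting the extra pixel in front when `p` is odd. Assuming `p` is computed as `len(k) - factor` (as in the upstream score_sde implementation this file follows; that line is not visible in this hunk), the default `fir_kernel=(1, 3, 3, 1)` with `factor=2` gives a symmetric `(1, 1)` pad:

for k_len, factor in [(4, 2), (3, 2), (5, 2)]:
    p = k_len - factor
    print(k_len, factor, ((p + 1) // 2, p // 2))
# 4 2 (1, 1)
# 3 2 (1, 0)
# 5 2 (2, 1)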
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
...
@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
    return conv


def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
...
@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
    return conv


conv1x1 = ddpm_conv1x1
conv3x3 = ddpm_conv3x3


def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)
...
@@ -494,135 +491,135 @@ class Downsample(nn.Module):
        return x
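`_einsum` just assembles an einsum equation from three lists of axis labels. As a quick, self-contained illustration of how such a helper contracts the channel axis of an NHWC tensor against a weight matrix (the way NIN/1x1-dense layers are commonly implemented; the shapes below are examples, not taken from the diff):

import torch

def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)

x = torch.randn(2, 8, 8, 16)                            # NHWC features
W = torch.randn(16, 32)                                 # channel-mixing weights
out = _einsum(list("bhwc"), list("cd"), list("bhwd"), x, W)
print(out.shape)                                        # torch.Size([2, 8, 8, 32])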
class ResnetBlockDDPMpp(nn.Module):
    """ResBlock adapted from DDPM."""

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        conv_shortcut=False,
        dropout=0.1,
        skip_rescale=False,
        init_scale=0.0,
    ):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))
        h = self.Conv_0(h)
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)
        if x.shape[1] != self.out_ch:
            if self.conv_shortcut:
                x = self.Conv_2(x)
            else:
                x = self.NIN_0(x)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)
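A note on `skip_rescale`: when the skip and residual branches are roughly independent with unit variance, their sum has variance 2, so dividing by sqrt(2) keeps activations on the same scale as the input. A quick numerical check (illustrative only):

import torch

x = torch.randn(100000)
h = torch.randn(100000)
print((x + h).std())                 # ~1.41: plain residual add grows the scale
print(((x + h) / 2 ** 0.5).std())    # ~1.0: rescaled skip keeps unit variance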
class ResnetBlockBigGANpp(nn.Module):
    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        up=False,
        down=False,
        dropout=0.1,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    ):
        super().__init__()

        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
            if self.fir:
                h = upsample_2d(h, self.fir_kernel, factor=2)
                x = upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_upsample_2d(h, factor=2)
                x = naive_upsample_2d(x, factor=2)
        elif self.down:
            if self.fir:
                h = downsample_2d(h, self.fir_kernel, factor=2)
                x = downsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_downsample_2d(h, factor=2)
                x = naive_downsample_2d(x, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)
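Both resnet blocks condition on time the same way: the embedding goes through the block's activation and a linear layer, then is added as a per-channel bias broadcast over the spatial dimensions. A minimal sketch of that broadcast (the `nn.SiLU` activation and the shapes are assumptions for illustration, not taken from the diff):

import torch
import torch.nn as nn

batch, temb_dim, out_ch = 2, 128, 64
act = nn.SiLU()
Dense_0 = nn.Linear(temb_dim, out_ch)

temb = torch.randn(batch, temb_dim)
h = torch.randn(batch, out_ch, 32, 32)
h = h + Dense_0(act(temb))[:, :, None, None]   # (batch, out_ch) -> (batch, out_ch, 1, 1), broadcast over H x W
print(h.shape)                                 # torch.Size([2, 64, 32, 32])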
class NCSNpp(ModelMixin, ConfigMixin):
    ...