Commit 1899457b (unverified), authored Jun 29, 2022 by Patrick von Platen, committed by GitHub on Jun 29, 2022

Merge pull request #40 from huggingface/start_resnet_unificiation

resnet in one file

Parents: e5d9baf0, ebf3717c
Showing 7 changed files with 1123 additions and 564 deletions (+1123 -564):
src/diffusers/models/resnet.py                     +697 -122
src/diffusers/models/unet.py                        +41  -43
src/diffusers/models/unet_glide.py                 +102 -111
src/diffusers/models/unet_grad_tts.py               +21  -19
src/diffusers/models/unet_ldm.py                   +100 -104
src/diffusers/models/unet_rl.py                     +30  -30
src/diffusers/models/unet_sde_score_estimation.py  +132 -135
src/diffusers/models/resnet.py
This diff is collapsed.
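The resnet.py diff is collapsed above, but the import hunks in the files below indicate which blocks the unified module now provides. A rough orientation sketch, based only on the imports visible in this diff (not on the full resnet.py contents):

# Sketch: blocks that the other files in this commit import from the unified module.
from diffusers.models.resnet import (
    Downsample,
    Upsample,
    ResnetBlock,            # unet.py
    ResBlock,               # unet_glide.py, unet_ldm.py
    TimestepBlock,          # unet_glide.py, unet_ldm.py
    ResnetBlockGradTTS,     # unet_grad_tts.py (aliased there as ResnetBlock)
    ResidualTemporalBlock,  # unet_rl.py
    ResnetBlockDDPMpp,      # unet_sde_score_estimation.py
    ResnetBlockBigGANpp,    # unet_sde_score_estimation.py
)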
src/diffusers/models/unet.py
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResnetBlock, Upsample


def nonlinearity(x):
@@ -34,46 +34,46 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


-class ResnetBlock(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
-        super().__init__()
-        self.in_channels = in_channels
-        out_channels = in_channels if out_channels is None else out_channels
-        self.out_channels = out_channels
-        self.use_conv_shortcut = conv_shortcut
-
-        self.norm1 = Normalize(in_channels)
-        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-        self.norm2 = Normalize(out_channels)
-        self.dropout = torch.nn.Dropout(dropout)
-        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x, temb):
-        h = x
-        h = self.norm1(h)
-        h = nonlinearity(h)
-        h = self.conv1(h)
-
-        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-
-        h = self.norm2(h)
-        h = nonlinearity(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
-
-        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
-            else:
-                x = self.nin_shortcut(x)
-
-        return x + h
+# class ResnetBlock(nn.Module):
+#     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
+#         super().__init__()
+#         self.in_channels = in_channels
+#         out_channels = in_channels if out_channels is None else out_channels
+#         self.out_channels = out_channels
+#         self.use_conv_shortcut = conv_shortcut
+#
+#         self.norm1 = Normalize(in_channels)
+#         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+#         self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+#         self.norm2 = Normalize(out_channels)
+#         self.dropout = torch.nn.Dropout(dropout)
+#         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+#         if self.in_channels != self.out_channels:
+#             if self.use_conv_shortcut:
+#                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+#             else:
+#                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+#
+#     def forward(self, x, temb):
+#         h = x
+#         h = self.norm1(h)
+#         h = nonlinearity(h)
+#         h = self.conv1(h)
+#
+#         h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+#
+#         h = self.norm2(h)
+#         h = nonlinearity(h)
+#         h = self.dropout(h)
+#         h = self.conv2(h)
+#
+#         if self.in_channels != self.out_channels:
+#             if self.use_conv_shortcut:
+#                 x = self.conv_shortcut(x)
+#             else:
+#                 x = self.nin_shortcut(x)
+#
+#         return x + h


class UNetModel(ModelMixin, ConfigMixin):
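For orientation, a minimal usage sketch of the block that moves out of this file; it assumes the relocated class keeps the keyword-only constructor shown in the removed code above (channel counts and shapes are illustrative):

import torch
from diffusers.models.resnet import ResnetBlock  # unified location after this commit

block = ResnetBlock(in_channels=128, out_channels=256, dropout=0.1, temb_channels=512)
x = torch.randn(2, 128, 32, 32)  # feature map
temb = torch.randn(2, 512)       # timestep embedding
h = block(x, temb)               # residual output, shape (2, 256, 32, 32)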
@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
-            attn_2 = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):

@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):
            down = nn.Module()
            down.block = block
            down.attn = attn
-            down.attn_2 = attn_2
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                curr_res = curr_res // 2
src/diffusers/models/unet_glide.py
@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResBlock, TimestepBlock, Upsample


def convert_module_to_f16(l):
@@ -96,16 +96,14 @@ def zero_module(module):
    return module


-class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
-
-    @abstractmethod
-    def forward(self, x, emb):
-        """
-        Apply the module to `x` given `emb` timestep embeddings.
-        """
+# class TimestepBlock(nn.Module):
+#     """
+#     Any module where forward() takes timestep embeddings as a second argument. #"""
+#
+#     @abstractmethod
+#     def forward(self, x, emb):
+#         """
+#         Apply the module to `x` given `emb` timestep embeddings. #"""


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
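TimestepBlock only defines an interface; TimestepEmbedSequential (whose body is elided in this hunk) uses it to decide which children receive the embedding. A self-contained sketch of that dispatch pattern, written from the interface shown here rather than copied from the file:

from abc import abstractmethod
import torch.nn as nn

class TimestepBlock(nn.Module):
    """Any module whose forward() takes timestep embeddings as a second argument."""

    @abstractmethod
    def forward(self, x, emb):
        ...

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """Passes `emb` only to children that implement the TimestepBlock interface."""

    def forward(self, x, emb):
        for layer in self:
            x = layer(x, emb) if isinstance(layer, TimestepBlock) else layer(x)
        return x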
@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x


-class ResBlock(TimestepBlock):
-    """
-    A residual block that can optionally change the number of channels.
-    :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
-    :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
-    use_conv: if True and out_channels is specified, use a spatial
-    convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
-    on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
-    downsampling.
-    """
-
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        use_conv=False,
-        use_scale_shift_norm=False,
-        dims=2,
-        use_checkpoint=False,
-        up=False,
-        down=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
-        self.use_scale_shift_norm = use_scale_shift_norm
-
-        self.in_layers = nn.Sequential(
-            normalization(channels, swish=1.0),
-            nn.Identity(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
-        )
-
-        self.updown = up or down
-
-        if up:
-            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
-            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
-        elif down:
-            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-        else:
-            self.h_upd = self.x_upd = nn.Identity()
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
-            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
-            nn.Dropout(p=dropout),
-            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        elif use_conv:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-    def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-        :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        if self.updown:
-            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
-            h = self.h_upd(h)
-            x = self.x_upd(x)
-            h = in_conv(h)
-        else:
-            h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-        if self.use_scale_shift_norm:
-            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
-        else:
-            h = h + emb_out
-            h = self.out_layers(h)
-        return self.skip_connection(x) + h
+# class ResBlock(TimestepBlock):
+#     """
+#     A residual block that can optionally change the number of channels.
+#     :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
+#     :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
+#     use_conv: if True and out_channels is specified, use a spatial
+#     convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
+#     :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
+#     on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
+#     downsampling. #"""
+#
+#     def __init__(
+#         self,
+#         channels,
+#         emb_channels,
+#         dropout,
+#         out_channels=None,
+#         use_conv=False,
+#         use_scale_shift_norm=False,
+#         dims=2,
+#         use_checkpoint=False,
+#         up=False,
+#         down=False,
+#     ):
+#         super().__init__()
+#         self.channels = channels
+#         self.emb_channels = emb_channels
+#         self.dropout = dropout
+#         self.out_channels = out_channels or channels
+#         self.use_conv = use_conv
+#         self.use_checkpoint = use_checkpoint
+#         self.use_scale_shift_norm = use_scale_shift_norm
+#
+#         self.in_layers = nn.Sequential(
+#             normalization(channels, swish=1.0),
+#             nn.Identity(),
+#             conv_nd(dims, channels, self.out_channels, 3, padding=1),
+#         )
+#
+#         self.updown = up or down
+#
+#         if up:
+#             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
+#             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
+#         elif down:
+#             self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+#             self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+#         else:
+#             self.h_upd = self.x_upd = nn.Identity()
+#
+#         self.emb_layers = nn.Sequential(
+#             nn.SiLU(),
+#             linear(
+#                 emb_channels,
+#                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+#             ),
+#         )
+#         self.out_layers = nn.Sequential(
+#             normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+#             nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+#             nn.Dropout(p=dropout),
+#             zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+#         )
+#
+#         if self.out_channels == channels:
+#             self.skip_connection = nn.Identity()
+#         elif use_conv:
+#             self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+#         else:
+#             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+#
+#     def forward(self, x, emb):
+#         """
+#         Apply the block to a Tensor, conditioned on a timestep embedding.
+#         :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+#         :return: an [N x C x ...] Tensor of outputs. #"""
+#         if self.updown:
+#             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+#             h = in_rest(x)
+#             h = self.h_upd(h)
+#             x = self.x_upd(x)
+#             h = in_conv(h)
+#         else:
+#             h = self.in_layers(x)
+#         emb_out = self.emb_layers(emb).type(h.dtype)
+#         while len(emb_out.shape) < len(h.shape):
+#             emb_out = emb_out[..., None]
+#         if self.use_scale_shift_norm:
+#             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+#             scale, shift = torch.chunk(emb_out, 2, dim=1)
+#             h = out_norm(h) * (1 + scale) + shift
+#             h = out_rest(h)
+#         else:
+#             h = h + emb_out
+#             h = self.out_layers(h)
+#         return self.skip_connection(x) + h


class GlideUNetModel(ModelMixin, ConfigMixin):
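The use_scale_shift_norm branch above is easy to misread in diff form: emb_layers emits 2 * out_channels features, which are split into a per-channel scale and shift applied around the output normalization. A standalone sketch of just that conditioning step (illustrative shapes, not code from the file):

import torch

batch, channels = 2, 64
h = torch.randn(batch, channels, 16, 16)          # feature map after out_norm
emb_out = torch.randn(batch, 2 * channels, 1, 1)  # output of emb_layers, broadcast over H and W

scale, shift = torch.chunk(emb_out, 2, dim=1)     # each (batch, channels, 1, 1)
h = h * (1 + scale) + shift                       # same update as `out_norm(h) * (1 + scale) + shift`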
src/diffusers/models/unet_grad_tts.py
@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample
+from .resnet import ResnetBlockGradTTS as ResnetBlock
+from .resnet import Upsample


class Mish(torch.nn.Module):
@@ -34,24 +36,24 @@ class Block(torch.nn.Module):
        return output * mask


-class ResnetBlock(torch.nn.Module):
-    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
-        super(ResnetBlock, self).__init__()
-        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        if dim != dim_out:
-            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
-        else:
-            self.res_conv = torch.nn.Identity()
-
-    def forward(self, x, mask, time_emb):
-        h = self.block1(x, mask)
-        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
-        h = self.block2(h, mask)
-        output = h + self.res_conv(x * mask)
-        return output
+# class ResnetBlock(torch.nn.Module):
+#     def __init__(self, dim, dim_out, time_emb_dim, groups=8):
+#         super(ResnetBlock, self).__init__()
+#         self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+#
+#         self.block1 = Block(dim, dim_out, groups=groups)
+#         self.block2 = Block(dim_out, dim_out, groups=groups)
+#         if dim != dim_out:
+#             self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
+#         else:
+#             self.res_conv = torch.nn.Identity()
+#
+#     def forward(self, x, mask, time_emb):
+#         h = self.block1(x, mask)
+#         h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
+#         h = self.block2(h, mask)
+#         output = h + self.res_conv(x * mask)
+#         return output


class Residual(torch.nn.Module):
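Because the block is imported as `ResnetBlockGradTTS as ResnetBlock`, call sites in this file keep their old spelling. A hedged usage sketch, assuming the relocated class keeps the `(dim, dim_out, time_emb_dim, groups=8)` constructor and the `forward(x, mask, time_emb)` interface shown above (shapes are illustrative):

import torch
from diffusers.models.resnet import ResnetBlockGradTTS as ResnetBlock

block = ResnetBlock(dim=64, dim_out=128, time_emb_dim=256)
x = torch.randn(2, 64, 80, 172)   # spectrogram-like features
mask = torch.ones(2, 1, 80, 172)  # grad-TTS style mask, broadcast over channels
t = torch.randn(2, 256)           # time embedding
out = block(x, mask, t)           # expected shape (2, 128, 80, 172)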
src/diffusers/models/unet_ldm.py
@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResBlock, TimestepBlock, Upsample


def exists(val):
@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):
        return x[:, :, 0]


-class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
-
-    @abstractmethod
-    def forward(self, x, emb):
-        """
-        Apply the module to `x` given `emb` timestep embeddings.
-        """
+# class TimestepBlock(nn.Module):
+#     """
+#     Any module where forward() takes timestep embeddings as a second argument. #"""
+#
+#     @abstractmethod
+#     def forward(self, x, emb):
+#         """
+#         Apply the module to `x` given `emb` timestep embeddings. #"""


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x


-class ResBlock(TimestepBlock):
-    """
-    A residual block that can optionally change the number of channels. :param channels: the number of input channels.
-    :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param
-    out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use
-    a spatial
-    convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
-    on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
-    downsampling.
-    """
-
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        use_conv=False,
-        use_scale_shift_norm=False,
-        dims=2,
-        use_checkpoint=False,
-        up=False,
-        down=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
-        self.use_scale_shift_norm = use_scale_shift_norm
-
-        self.in_layers = nn.Sequential(
-            normalization(channels),
-            nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
-        )
-
-        self.updown = up or down
-
-        if up:
-            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
-            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
-        elif down:
-            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-        else:
-            self.h_upd = self.x_upd = nn.Identity()
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
-            nn.SiLU(),
-            nn.Dropout(p=dropout),
-            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        elif use_conv:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-    def forward(self, x, emb):
-        if self.updown:
-            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
-            h = self.h_upd(h)
-            x = self.x_upd(x)
-            h = in_conv(h)
-        else:
-            h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-        if self.use_scale_shift_norm:
-            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
-        else:
-            h = h + emb_out
-            h = self.out_layers(h)
-        return self.skip_connection(x) + h
+# class A_ResBlock(TimestepBlock):
+#     """
+#     A residual block that can optionally change the number of channels. :param channels: the number of input channels.
+#     :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param
+#     out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use
+#     a spatial
+#     convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
+#     :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
+#     on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
+#     downsampling. #"""
+#
+#     def __init__(
+#         self,
+#         channels,
+#         emb_channels,
+#         dropout,
+#         out_channels=None,
+#         use_conv=False,
+#         use_scale_shift_norm=False,
+#         dims=2,
+#         use_checkpoint=False,
+#         up=False,
+#         down=False,
+#     ):
+#         super().__init__()
+#         self.channels = channels
+#         self.emb_channels = emb_channels
+#         self.dropout = dropout
+#         self.out_channels = out_channels or channels
+#         self.use_conv = use_conv
+#         self.use_checkpoint = use_checkpoint
+#         self.use_scale_shift_norm = use_scale_shift_norm
+#
+#         self.in_layers = nn.Sequential(
+#             normalization(channels),
+#             nn.SiLU(),
+#             conv_nd(dims, channels, self.out_channels, 3, padding=1),
+#         )
+#
+#         self.updown = up or down
+#
+#         if up:
+#             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
+#             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
+#         elif down:
+#             self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+#             self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+#         else:
+#             self.h_upd = self.x_upd = nn.Identity()
+#
+#         self.emb_layers = nn.Sequential(
+#             nn.SiLU(),
+#             linear(
+#                 emb_channels,
+#                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+#             ),
+#         )
+#         self.out_layers = nn.Sequential(
+#             normalization(self.out_channels),
+#             nn.SiLU(),
+#             nn.Dropout(p=dropout),
+#             zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+#         )
+#
+#         if self.out_channels == channels:
+#             self.skip_connection = nn.Identity()
+#         elif use_conv:
+#             self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+#         else:
+#             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+#
+#     def forward(self, x, emb):
+#         if self.updown:
+#             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+#             h = in_rest(x)
+#             h = self.h_upd(h)
+#             x = self.x_upd(x)
+#             h = in_conv(h)
+#         else:
+#             h = self.in_layers(x)
+#         emb_out = self.emb_layers(emb).type(h.dtype)
+#         while len(emb_out.shape) < len(h.shape):
+#             emb_out = emb_out[..., None]
+#         if self.use_scale_shift_norm:
+#             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+#             scale, shift = torch.chunk(emb_out, 2, dim=1)
+#             h = out_norm(h) * (1 + scale) + shift
+#             h = out_rest(h)
+#         else:
+#             h = h + emb_out
+#             h = self.out_layers(h)
+#         return self.skip_connection(x) + h
+#


class QKVAttention(nn.Module):
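One behavioural detail worth noting in the block being moved here: when `up` or `down` is set, the same resampling modules (`h_upd` / `x_upd`) are applied to both the hidden branch and the skip branch, so the residual addition still sees matching spatial sizes. A hedged usage sketch, assuming the relocated ResBlock keeps the constructor shown above:

import torch
from diffusers.models.resnet import ResBlock  # unified location after this commit

# Downsampling residual block: 64 -> 128 channels, spatial size halved.
block = ResBlock(channels=64, emb_channels=512, dropout=0.0, out_channels=128, down=True)
x = torch.randn(2, 64, 32, 32)
emb = torch.randn(2, 512)
h = block(x, emb)  # expected shape (2, 128, 16, 16)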
src/diffusers/models/unet_rl.py
@@ -6,6 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
+from .resnet import ResidualTemporalBlock


class SinusoidalPosEmb(nn.Module):
@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):
        return self.block(x)


-class ResidualTemporalBlock(nn.Module):
-    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
-        super().__init__()
-
-        self.blocks = nn.ModuleList(
-            [
-                Conv1dBlock(inp_channels, out_channels, kernel_size),
-                Conv1dBlock(out_channels, out_channels, kernel_size),
-            ]
-        )
-
-        self.time_mlp = nn.Sequential(
-            nn.Mish(),
-            nn.Linear(embed_dim, out_channels),
-            RearrangeDim(),
-            # Rearrange("batch t -> batch t 1"),
-        )
-
-        self.residual_conv = (
-            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
-        )
-
-    def forward(self, x, t):
-        """
-        x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
-        out_channels x horizon ]
-        """
-        out = self.blocks[0](x) + self.time_mlp(t)
-        out = self.blocks[1](out)
-        return out + self.residual_conv(x)
+# class ResidualTemporalBlock(nn.Module):
+#     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
+#         super().__init__()
+#
+#         self.blocks = nn.ModuleList(
+#             [
+#                 Conv1dBlock(inp_channels, out_channels, kernel_size),
+#                 Conv1dBlock(out_channels, out_channels, kernel_size),
+#             ]
+#         )
+#
+#         self.time_mlp = nn.Sequential(
+#             nn.Mish(),
+#             nn.Linear(embed_dim, out_channels),
+#             RearrangeDim(),
+#             # Rearrange("batch t -> batch t 1"),
+#         )
+#
+#         self.residual_conv = (
+#             nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+#         )
+#
+#     def forward(self, x, t):
+#         """
+#         x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
+#         out_channels x horizon ]
+#         """
+#         out = self.blocks[0](x) + self.time_mlp(t)
+#         out = self.blocks[1](out)
+#         return out + self.residual_conv(x)


class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
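The forward docstring above already states the tensor contract; a quick worked example of those shapes (illustrative values, assuming the relocated class keeps the `(inp_channels, out_channels, embed_dim, horizon)` constructor and `forward(x, t)` shown above):

import torch
from diffusers.models.resnet import ResidualTemporalBlock

block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=128, horizon=64)
x = torch.randn(8, 14, 64)  # [ batch_size x inp_channels x horizon ]
t = torch.randn(8, 128)     # [ batch_size x embed_dim ]
out = block(x, t)           # [ batch_size x out_channels x horizon ] -> (8, 32, 64)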
src/diffusers/models/unet_sde_score_estimation.py
@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
+from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))


-def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
+def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
    return conv


-def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
+def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
    return conv


-conv1x1 = ddpm_conv1x1
-conv3x3 = ddpm_conv3x3
-
-
def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)
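The `_einsum` helper shown as context above only assembles an einsum equation from character lists. A small worked example of what it computes:

import torch

def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)

x = torch.randn(4, 3)
y = torch.randn(3, 5)
out = _einsum(["a", "b"], ["b", "c"], ["a", "c"], x, y)  # "ab,bc->ac": a plain matrix product
assert torch.allclose(out, x @ y, atol=1e-6)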
@@ -494,135 +491,135 @@ class Downsample(nn.Module):
        return x


-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
-class ResnetBlockBigGANpp(nn.Module):
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        up=False,
-        down=False,
-        dropout=0.1,
-        fir=False,
-        fir_kernel=(1, 3, 3, 1),
-        skip_rescale=True,
-        init_scale=0.0,
-    ):
-        super().__init__()
-
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.up = up
-        self.down = down
-        self.fir = fir
-        self.fir_kernel = fir_kernel
-
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.in_ch = in_ch
-        self.out_ch = out_ch
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-
-        if self.up:
-            if self.fir:
-                h = upsample_2d(h, self.fir_kernel, factor=2)
-                x = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_upsample_2d(h, factor=2)
-                x = naive_upsample_2d(x, factor=2)
-        elif self.down:
-            if self.fir:
-                h = downsample_2d(h, self.fir_kernel, factor=2)
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_downsample_2d(h, factor=2)
-                x = naive_downsample_2d(x, factor=2)
-
-        h = self.Conv_0(h)
-        # Add bias to each feature map conditioned on the time embedding
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-
-        if self.in_ch != self.out_ch or self.up or self.down:
-            x = self.Conv_2(x)
-
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
+# class ResnetBlockDDPMpp(nn.Module):
+#     """ResBlock adapted from DDPM."""
+#
+#     def __init__(
+#         self,
+#         act,
+#         in_ch,
+#         out_ch=None,
+#         temb_dim=None,
+#         conv_shortcut=False,
+#         dropout=0.1,
+#         skip_rescale=False,
+#         init_scale=0.0,
+#     ):
+#         super().__init__()
+#         out_ch = out_ch if out_ch else in_ch
+#         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+#         self.Conv_0 = conv3x3(in_ch, out_ch)
+#         if temb_dim is not None:
+#             self.Dense_0 = nn.Linear(temb_dim, out_ch)
+#             self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+#             nn.init.zeros_(self.Dense_0.bias)
+#
+#         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+#         self.Dropout_0 = nn.Dropout(dropout)
+#         self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+#         if in_ch != out_ch:
+#             if conv_shortcut:
+#                 self.Conv_2 = conv3x3(in_ch, out_ch)
+#             else:
+#                 self.NIN_0 = NIN(in_ch, out_ch)
+#
+#         self.skip_rescale = skip_rescale
+#         self.act = act
+#         self.out_ch = out_ch
+#         self.conv_shortcut = conv_shortcut
+#
+#     def forward(self, x, temb=None):
+#         h = self.act(self.GroupNorm_0(x))
+#         h = self.Conv_0(h)
+#         if temb is not None:
+#             h += self.Dense_0(self.act(temb))[:, :, None, None]
+#         h = self.act(self.GroupNorm_1(h))
+#         h = self.Dropout_0(h)
+#         h = self.Conv_1(h)
+#         if x.shape[1] != self.out_ch:
+#             if self.conv_shortcut:
+#                 x = self.Conv_2(x)
+#             else:
+#                 x = self.NIN_0(x)
+#         if not self.skip_rescale:
+#             return x + h
+#         else:
+#             return (x + h) / np.sqrt(2.0)
+#
+#
+# class ResnetBlockBigGANpp(nn.Module):
+#     def __init__(
+#         self,
+#         act,
+#         in_ch,
+#         out_ch=None,
+#         temb_dim=None,
+#         up=False,
+#         down=False,
+#         dropout=0.1,
+#         fir=False,
+#         fir_kernel=(1, 3, 3, 1),
+#         skip_rescale=True,
+#         init_scale=0.0,
+#     ):
+#         super().__init__()
+#
+#         out_ch = out_ch if out_ch else in_ch
+#         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+#         self.up = up
+#         self.down = down
+#         self.fir = fir
+#         self.fir_kernel = fir_kernel
+#
+#         self.Conv_0 = conv3x3(in_ch, out_ch)
+#         if temb_dim is not None:
+#             self.Dense_0 = nn.Linear(temb_dim, out_ch)
+#             self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+#             nn.init.zeros_(self.Dense_0.bias)
+#
+#         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+#         self.Dropout_0 = nn.Dropout(dropout)
+#         self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+#         if in_ch != out_ch or up or down:
+#             self.Conv_2 = conv1x1(in_ch, out_ch)
+#
+#         self.skip_rescale = skip_rescale
+#         self.act = act
+#         self.in_ch = in_ch
+#         self.out_ch = out_ch
+#
+#     def forward(self, x, temb=None):
+#         h = self.act(self.GroupNorm_0(x))
+#
+#         if self.up:
+#             if self.fir:
+#                 h = upsample_2d(h, self.fir_kernel, factor=2)
+#                 x = upsample_2d(x, self.fir_kernel, factor=2)
+#             else:
+#                 h = naive_upsample_2d(h, factor=2)
+#                 x = naive_upsample_2d(x, factor=2)
+#         elif self.down:
+#             if self.fir:
+#                 h = downsample_2d(h, self.fir_kernel, factor=2)
+#                 x = downsample_2d(x, self.fir_kernel, factor=2)
+#             else:
+#                 h = naive_downsample_2d(h, factor=2)
+#                 x = naive_downsample_2d(x, factor=2)
+#
+#         h = self.Conv_0(h)
+#         # Add bias to each feature map conditioned on the time embedding
+#         if temb is not None:
+#             h += self.Dense_0(self.act(temb))[:, :, None, None]
+#         h = self.act(self.GroupNorm_1(h))
+#         h = self.Dropout_0(h)
+#         h = self.Conv_1(h)
+#
+#         if self.in_ch != self.out_ch or self.up or self.down:
+#             x = self.Conv_2(x)
+#
+#         if not self.skip_rescale:
+#             return x + h
+#         else:
+#             return (x + h) / np.sqrt(2.0)


class NCSNpp(ModelMixin, ConfigMixin):
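Both score-SDE blocks end with the same choice: a plain residual sum, or the rescaled `(x + h) / sqrt(2)` used when `skip_rescale=True`, which keeps the variance of the sum close to that of its inputs when the two branches are roughly independent with similar variance. A numerical sketch of that rationale (not code from the file):

import torch

x = torch.randn(100_000)
h = torch.randn(100_000)
print((x + h).var().item())                   # ~2.0
print(((x + h) / (2.0 ** 0.5)).var().item())  # ~1.0, matching skip_rescale=True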