Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
1899457b
Unverified
Commit
1899457b
authored
Jun 29, 2022
by
Patrick von Platen
Committed by
GitHub
Jun 29, 2022
Browse files
Merge pull request #40 from huggingface/start_resnet_unificiation
resnet in one file
parents
e5d9baf0
ebf3717c
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1123 additions
and
564 deletions
+1123
-564
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+697
-122
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+41
-43
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+102
-111
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+21
-19
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+100
-104
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+30
-30
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+132
-135
No files found.
src/diffusers/models/resnet.py
View file @
1899457b
This diff is collapsed.
Click to expand it.
src/diffusers/models/unet.py
View file @
1899457b
...
...
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
,
Upsample
def
nonlinearity
(
x
):
...
...
@@ -34,46 +34,46 @@ def Normalize(in_channels):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
#
class ResnetBlock(nn.Module):
#
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
#
super().__init__()
#
self.in_channels = in_channels
#
out_channels = in_channels if out_channels is None else out_channels
#
self.out_channels = out_channels
#
self.use_conv_shortcut = conv_shortcut
#
#
self.norm1 = Normalize(in_channels)
#
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
#
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
#
self.norm2 = Normalize(out_channels)
#
self.dropout = torch.nn.Dropout(dropout)
#
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
#
if self.in_channels != self.out_channels:
#
if self.use_conv_shortcut:
#
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
#
else:
#
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
#
#
def forward(self, x, temb):
#
h = x
#
h = self.norm1(h)
#
h = nonlinearity(h)
#
h = self.conv1(h)
#
#
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
#
#
h = self.norm2(h)
#
h = nonlinearity(h)
#
h = self.dropout(h)
#
h = self.conv2(h)
#
#
if self.in_channels != self.out_channels:
#
if self.use_conv_shortcut:
#
x = self.conv_shortcut(x)
#
else:
#
x = self.nin_shortcut(x)
#
#
return x + h
class
UNetModel
(
ModelMixin
,
ConfigMixin
):
...
...
@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
attn_2
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
...
...
@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
down
.
attn_2
=
attn_2
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
use_conv
=
resamp_with_conv
,
padding
=
0
)
curr_res
=
curr_res
//
2
...
...
src/diffusers/models/unet_glide.py
View file @
1899457b
...
...
@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
Upsample
from
.resnet
import
Downsample
,
ResBlock
,
TimestepBlock
,
Upsample
def
convert_module_to_f16
(
l
):
...
...
@@ -96,16 +96,14 @@ def zero_module(module):
return
module
class
TimestepBlock
(
nn
.
Module
):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@
abstractmethod
def
forward
(
self
,
x
,
emb
):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument. #"""
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings. #"""
class
TimestepEmbedSequential
(
nn
.
Sequential
,
TimestepBlock
):
...
...
@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return
x
class
ResBlock
(
TimestepBlock
):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""
def
__init__
(
self
,
channels
,
emb_channels
,
dropout
,
out_channels
=
None
,
use_conv
=
False
,
use_scale_shift_norm
=
False
,
dims
=
2
,
use_checkpoint
=
False
,
up
=
False
,
down
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
emb_channels
=
emb_channels
self
.
dropout
=
dropout
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
use_checkpoint
=
use_checkpoint
self
.
use_scale_shift_norm
=
use_scale_shift_norm
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
,
swish
=
1.0
),
nn
.
Identity
(),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
),
)
self
.
updown
=
up
or
down
if
up
:
self
.
h_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
x_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
self
.
emb_layers
=
nn
.
Sequential
(
nn
.
SiLU
(),
linear
(
emb_channels
,
2
*
self
.
out_channels
if
use_scale_shift_norm
else
self
.
out_channels
,
),
)
self
.
out_layers
=
nn
.
Sequential
(
normalization
(
self
.
out_channels
,
swish
=
0.0
if
use_scale_shift_norm
else
1.0
),
nn
.
SiLU
()
if
use_scale_shift_norm
else
nn
.
Identity
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)),
)
if
self
.
out_channels
==
channels
:
self
.
skip_connection
=
nn
.
Identity
()
elif
use_conv
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
def
forward
(
self
,
x
,
emb
):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if
self
.
updown
:
in_rest
,
in_conv
=
self
.
in_layers
[:
-
1
],
self
.
in_layers
[
-
1
]
h
=
in_rest
(
x
)
h
=
self
.
h_upd
(
h
)
x
=
self
.
x_upd
(
x
)
h
=
in_conv
(
h
)
else
:
h
=
self
.
in_layers
(
x
)
emb_out
=
self
.
emb_layers
(
emb
).
type
(
h
.
dtype
)
while
len
(
emb_out
.
shape
)
<
len
(
h
.
shape
):
emb_out
=
emb_out
[...,
None
]
if
self
.
use_scale_shift_norm
:
out_norm
,
out_rest
=
self
.
out_layers
[
0
],
self
.
out_layers
[
1
:]
scale
,
shift
=
torch
.
chunk
(
emb_out
,
2
,
dim
=
1
)
h
=
out_norm
(
h
)
*
(
1
+
scale
)
+
shift
h
=
out_rest
(
h
)
else
:
h
=
h
+
emb_out
h
=
self
.
out_layers
(
h
)
return
self
.
skip_connection
(
x
)
+
h
# class ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. # # :param channels: the number of input
channels
.
:
param
emb_channels
:
the
number
of
timestep
embedding
channels
.
# :param dropout: the rate of dropout. :param
out_channels
:
if
specified
,
the
number
of
out
channels
.
:
param
# use_conv: if True and out_channels is specified, use a
spatial
# convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
dims
:
determines
if
the
signal
is
1
D
,
2
D
,
or
3
D
.
:
param
use_checkpoint
:
if
True
,
use
gradient
checkpointing
# on this
module
.
:
param
up
:
if
True
,
use
this
block
for
upsampling
.
:
param
down
:
if
True
,
use
this
block
for
# downsampling. #"""
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels, swish=1.0),
# nn.Identity(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# """
# Apply the block to a Tensor, conditioned on a timestep embedding. # # :param x: an [N x C x ...] Tensor of features.
:
param
emb
:
an
[
N
x
emb_channels
]
Tensor
of
timestep
embeddings
.
# :return: an [N x C x ...] Tensor of outputs. #"""
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
class
GlideUNetModel
(
ModelMixin
,
ConfigMixin
):
...
...
src/diffusers/models/unet_grad_tts.py
View file @
1899457b
...
...
@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
LinearAttention
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
Upsample
from
.resnet
import
Downsample
from
.resnet
import
ResnetBlockGradTTS
as
ResnetBlock
from
.resnet
import
Upsample
class
Mish
(
torch
.
nn
.
Module
):
...
...
@@ -34,24 +36,24 @@ class Block(torch.nn.Module):
return
output
*
mask
class
ResnetBlock
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_out
,
time_emb_dim
,
groups
=
8
):
super
(
ResnetBlock
,
self
).
__init__
()
self
.
mlp
=
torch
.
nn
.
Sequential
(
Mish
(),
torch
.
nn
.
Linear
(
time_emb_dim
,
dim_out
))
self
.
block1
=
Block
(
dim
,
dim_out
,
groups
=
groups
)
self
.
block2
=
Block
(
dim_out
,
dim_out
,
groups
=
groups
)
if
dim
!=
dim_out
:
self
.
res_conv
=
torch
.
nn
.
Conv2d
(
dim
,
dim_out
,
1
)
else
:
self
.
res_conv
=
torch
.
nn
.
Identity
()
def
forward
(
self
,
x
,
mask
,
time_emb
):
h
=
self
.
block1
(
x
,
mask
)
h
+=
self
.
mlp
(
time_emb
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
h
=
self
.
block2
(
h
,
mask
)
output
=
h
+
self
.
res_conv
(
x
*
mask
)
return
output
#
class ResnetBlock(torch.nn.Module):
#
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
#
super(ResnetBlock, self).__init__()
#
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
#
#
self.block1 = Block(dim, dim_out, groups=groups)
#
self.block2 = Block(dim_out, dim_out, groups=groups)
#
if dim != dim_out:
#
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
#
else:
#
self.res_conv = torch.nn.Identity()
#
#
def forward(self, x, mask, time_emb):
#
h = self.block1(x, mask)
#
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
#
h = self.block2(h, mask)
#
output = h + self.res_conv(x * mask)
#
return output
class
Residual
(
torch
.
nn
.
Module
):
...
...
src/diffusers/models/unet_ldm.py
View file @
1899457b
...
...
@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
Upsample
from
.resnet
import
Downsample
,
ResBlock
,
TimestepBlock
,
Upsample
def
exists
(
val
):
...
...
@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):
return
x
[:,
:,
0
]
class
TimestepBlock
(
nn
.
Module
):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@
abstractmethod
def
forward
(
self
,
x
,
emb
):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument. #"""
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings. #"""
class
TimestepEmbedSequential
(
nn
.
Sequential
,
TimestepBlock
):
...
...
@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return
x
class
ResBlock
(
TimestepBlock
):
"""
A residual block that can optionally change the number of channels. :param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param
out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use
a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""
def
__init__
(
self
,
channels
,
emb_channels
,
dropout
,
out_channels
=
None
,
use_conv
=
False
,
use_scale_shift_norm
=
False
,
dims
=
2
,
use_checkpoint
=
False
,
up
=
False
,
down
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
emb_channels
=
emb_channels
self
.
dropout
=
dropout
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
use_checkpoint
=
use_checkpoint
self
.
use_scale_shift_norm
=
use_scale_shift_norm
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
),
nn
.
SiLU
(),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
),
)
self
.
updown
=
up
or
down
if
up
:
self
.
h_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
x_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
self
.
emb_layers
=
nn
.
Sequential
(
nn
.
SiLU
(),
linear
(
emb_channels
,
2
*
self
.
out_channels
if
use_scale_shift_norm
else
self
.
out_channels
,
),
)
self
.
out_layers
=
nn
.
Sequential
(
normalization
(
self
.
out_channels
),
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)),
)
if
self
.
out_channels
==
channels
:
self
.
skip_connection
=
nn
.
Identity
()
elif
use_conv
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
def
forward
(
self
,
x
,
emb
):
if
self
.
updown
:
in_rest
,
in_conv
=
self
.
in_layers
[:
-
1
],
self
.
in_layers
[
-
1
]
h
=
in_rest
(
x
)
h
=
self
.
h_upd
(
h
)
x
=
self
.
x_upd
(
x
)
h
=
in_conv
(
h
)
else
:
h
=
self
.
in_layers
(
x
)
emb_out
=
self
.
emb_layers
(
emb
).
type
(
h
.
dtype
)
while
len
(
emb_out
.
shape
)
<
len
(
h
.
shape
):
emb_out
=
emb_out
[...,
None
]
if
self
.
use_scale_shift_norm
:
out_norm
,
out_rest
=
self
.
out_layers
[
0
],
self
.
out_layers
[
1
:]
scale
,
shift
=
torch
.
chunk
(
emb_out
,
2
,
dim
=
1
)
h
=
out_norm
(
h
)
*
(
1
+
scale
)
+
shift
h
=
out_rest
(
h
)
else
:
h
=
h
+
emb_out
h
=
self
.
out_layers
(
h
)
return
self
.
skip_connection
(
x
)
+
h
# class A_ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. :param channels: the number of input channels. #
:
param
emb_channels
:
the
number
of
timestep
embedding
channels
.
:
param
dropout
:
the
rate
of
dropout
.
:
param
#
out_channels
:
if
specified
,
the
number
of
out
channels
.
:
param
use_conv
:
if
True
and
out_channels
is
specified
,
use
# a
spatial
# convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
dims
:
determines
if
the
signal
is
1
D
,
2
D
,
or
3
D
.
:
param
use_checkpoint
:
if
True
,
use
gradient
checkpointing
# on this
module
.
:
param
up
:
if
True
,
use
this
block
for
upsampling
.
:
param
down
:
if
True
,
use
this
block
for
# downsampling. #"""
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels),
# nn.SiLU(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels),
# nn.SiLU(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
#
class
QKVAttention
(
nn
.
Module
):
...
...
src/diffusers/models/unet_rl.py
View file @
1899457b
...
...
@@ -6,6 +6,7 @@ import torch.nn as nn
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
ResidualTemporalBlock
class
SinusoidalPosEmb
(
nn
.
Module
):
...
...
@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):
return
self
.
block
(
x
)
class
ResidualTemporalBlock
(
nn
.
Module
):
def
__init__
(
self
,
inp_channels
,
out_channels
,
embed_dim
,
horizon
,
kernel_size
=
5
):
super
().
__init__
()
self
.
blocks
=
nn
.
ModuleList
(
[
Conv1dBlock
(
inp_channels
,
out_channels
,
kernel_size
),
Conv1dBlock
(
out_channels
,
out_channels
,
kernel_size
),
]
)
self
.
time_mlp
=
nn
.
Sequential
(
nn
.
Mish
(),
nn
.
Linear
(
embed_dim
,
out_channels
),
RearrangeDim
(),
# Rearrange("batch t -> batch t 1"),
)
self
.
residual_conv
=
(
nn
.
Conv1d
(
inp_channels
,
out_channels
,
1
)
if
inp_channels
!=
out_channels
else
nn
.
Identity
()
)
def
forward
(
self
,
x
,
t
):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out
=
self
.
blocks
[
0
](
x
)
+
self
.
time_mlp
(
t
)
out
=
self
.
blocks
[
1
](
out
)
return
out
+
self
.
residual_conv
(
x
)
# class ResidualTemporalBlock(nn.Module):
# def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
# super().__init__()
#
# self.blocks = nn.ModuleList(
# [
# Conv1dBlock(inp_channels, out_channels, kernel_size),
# Conv1dBlock(out_channels, out_channels, kernel_size),
# ]
# )
#
# self.time_mlp = nn.Sequential(
# nn.Mish(),
# nn.Linear(embed_dim, out_channels),
# RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
# )
#
# self.residual_conv = (
# nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
# )
#
# def forward(self, x, t):
# """
# x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x #
out_channels
x
horizon
]
#"""
# out = self.blocks[0](x) + self.time_mlp(t)
# out = self.blocks[1](out)
# return out + self.residual_conv(x)
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
# (nn.Module):
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
1899457b
...
...
@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
ResnetBlockBigGANpp
,
ResnetBlockDDPMpp
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
...
...
@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
return
upfirdn2d
(
x
,
torch
.
tensor
(
k
,
device
=
x
.
device
),
down
=
factor
,
pad
=
((
p
+
1
)
//
2
,
p
//
2
))
def
ddpm_
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
,
bias
=
True
,
init_scale
=
1.0
,
padding
=
0
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
,
bias
=
True
,
init_scale
=
1.0
,
padding
=
0
):
"""1x1 convolution with DDPM initialization."""
conv
=
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
padding
=
padding
,
bias
=
bias
)
conv
.
weight
.
data
=
default_init
(
init_scale
)(
conv
.
weight
.
data
.
shape
)
...
...
@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
return
conv
def
ddpm_
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
,
bias
=
True
,
dilation
=
1
,
init_scale
=
1.0
,
padding
=
1
):
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
,
bias
=
True
,
dilation
=
1
,
init_scale
=
1.0
,
padding
=
1
):
"""3x3 convolution with DDPM initialization."""
conv
=
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
...
...
@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
return
conv
conv1x1
=
ddpm_conv1x1
conv3x3
=
ddpm_conv3x3
def
_einsum
(
a
,
b
,
c
,
x
,
y
):
einsum_str
=
"{},{}->{}"
.
format
(
""
.
join
(
a
),
""
.
join
(
b
),
""
.
join
(
c
))
return
torch
.
einsum
(
einsum_str
,
x
,
y
)
...
...
@@ -494,135 +491,135 @@ class Downsample(nn.Module):
return
x
class
ResnetBlockDDPMpp
(
nn
.
Module
):
"""ResBlock adapted from DDPM."""
def
__init__
(
self
,
act
,
in_ch
,
out_ch
=
None
,
temb_dim
=
None
,
conv_shortcut
=
False
,
dropout
=
0.1
,
skip_rescale
=
False
,
init_scale
=
0.0
,
):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
Conv_0
=
conv3x3
(
in_ch
,
out_ch
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
self
.
Dense_0
.
weight
.
data
=
default_init
()(
self
.
Dense_0
.
weight
.
data
.
shape
)
nn
.
init
.
zeros_
(
self
.
Dense_0
.
bias
)
self
.
GroupNorm_1
=
nn
.
GroupNorm
(
num_groups
=
min
(
out_ch
//
4
,
32
),
num_channels
=
out_ch
,
eps
=
1e-6
)
self
.
Dropout_0
=
nn
.
Dropout
(
dropout
)
self
.
Conv_1
=
conv3x3
(
out_ch
,
out_ch
,
init_scale
=
init_scale
)
if
in_ch
!=
out_ch
:
if
conv_shortcut
:
self
.
Conv_2
=
conv3x3
(
in_ch
,
out_ch
)
else
:
self
.
NIN_0
=
NIN
(
in_ch
,
out_ch
)
self
.
skip_rescale
=
skip_rescale
self
.
act
=
act
self
.
out_ch
=
out_ch
self
.
conv_shortcut
=
conv_shortcut
def
forward
(
self
,
x
,
temb
=
None
):
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
h
=
self
.
Conv_0
(
h
)
if
temb
is
not
None
:
h
+=
self
.
Dense_0
(
self
.
act
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
act
(
self
.
GroupNorm_1
(
h
))
h
=
self
.
Dropout_0
(
h
)
h
=
self
.
Conv_1
(
h
)
if
x
.
shape
[
1
]
!=
self
.
out_ch
:
if
self
.
conv_shortcut
:
x
=
self
.
Conv_2
(
x
)
else
:
x
=
self
.
NIN_0
(
x
)
if
not
self
.
skip_rescale
:
return
x
+
h
else
:
return
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
class
ResnetBlockBigGANpp
(
nn
.
Module
):
def
__init__
(
self
,
act
,
in_ch
,
out_ch
=
None
,
temb_dim
=
None
,
up
=
False
,
down
=
False
,
dropout
=
0.1
,
fir
=
False
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
skip_rescale
=
True
,
init_scale
=
0.0
,
):
super
().
__init__
()
out_ch
=
out_ch
if
out_ch
else
in_ch
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
min
(
in_ch
//
4
,
32
),
num_channels
=
in_ch
,
eps
=
1e-6
)
self
.
up
=
up
self
.
down
=
down
self
.
fir
=
fir
self
.
fir_kernel
=
fir_kernel
self
.
Conv_0
=
conv3x3
(
in_ch
,
out_ch
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
self
.
Dense_0
.
weight
.
data
=
default_init
()(
self
.
Dense_0
.
weight
.
shape
)
nn
.
init
.
zeros_
(
self
.
Dense_0
.
bias
)
self
.
GroupNorm_1
=
nn
.
GroupNorm
(
num_groups
=
min
(
out_ch
//
4
,
32
),
num_channels
=
out_ch
,
eps
=
1e-6
)
self
.
Dropout_0
=
nn
.
Dropout
(
dropout
)
self
.
Conv_1
=
conv3x3
(
out_ch
,
out_ch
,
init_scale
=
init_scale
)
if
in_ch
!=
out_ch
or
up
or
down
:
self
.
Conv_2
=
conv1x1
(
in_ch
,
out_ch
)
self
.
skip_rescale
=
skip_rescale
self
.
act
=
act
self
.
in_ch
=
in_ch
self
.
out_ch
=
out_ch
def
forward
(
self
,
x
,
temb
=
None
):
h
=
self
.
act
(
self
.
GroupNorm_0
(
x
))
if
self
.
up
:
if
self
.
fir
:
h
=
upsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
upsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
else
:
h
=
naive_upsample_2d
(
h
,
factor
=
2
)
x
=
naive_upsample_2d
(
x
,
factor
=
2
)
elif
self
.
down
:
if
self
.
fir
:
h
=
downsample_2d
(
h
,
self
.
fir_kernel
,
factor
=
2
)
x
=
downsample_2d
(
x
,
self
.
fir_kernel
,
factor
=
2
)
else
:
h
=
naive_downsample_2d
(
h
,
factor
=
2
)
x
=
naive_downsample_2d
(
x
,
factor
=
2
)
h
=
self
.
Conv_0
(
h
)
# Add bias to each feature map conditioned on the time embedding
if
temb
is
not
None
:
h
+=
self
.
Dense_0
(
self
.
act
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
act
(
self
.
GroupNorm_1
(
h
))
h
=
self
.
Dropout_0
(
h
)
h
=
self
.
Conv_1
(
h
)
if
self
.
in_ch
!=
self
.
out_ch
or
self
.
up
or
self
.
down
:
x
=
self
.
Conv_2
(
x
)
if
not
self
.
skip_rescale
:
return
x
+
h
else
:
return
(
x
+
h
)
/
np
.
sqrt
(
2.0
)
#
class ResnetBlockDDPMpp(nn.Module):
#
"""ResBlock adapted from DDPM."""
#
#
def __init__(
#
self,
#
act,
#
in_ch,
#
out_ch=None,
#
temb_dim=None,
#
conv_shortcut=False,
#
dropout=0.1,
#
skip_rescale=False,
#
init_scale=0.0,
#
):
#
super().__init__()
#
out_ch = out_ch if out_ch else in_ch
#
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
#
self.Conv_0 = conv3x3(in_ch, out_ch)
#
if temb_dim is not None:
#
self.Dense_0 = nn.Linear(temb_dim, out_ch)
#
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
#
nn.init.zeros_(self.Dense_0.bias)
#
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
#
self.Dropout_0 = nn.Dropout(dropout)
#
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
#
if in_ch != out_ch:
#
if conv_shortcut:
#
self.Conv_2 = conv3x3(in_ch, out_ch)
#
else:
#
self.NIN_0 = NIN(in_ch, out_ch)
#
#
self.skip_rescale = skip_rescale
#
self.act = act
#
self.out_ch = out_ch
#
self.conv_shortcut = conv_shortcut
#
#
def forward(self, x, temb=None):
#
h = self.act(self.GroupNorm_0(x))
#
h = self.Conv_0(h)
#
if temb is not None:
#
h += self.Dense_0(self.act(temb))[:, :, None, None]
#
h = self.act(self.GroupNorm_1(h))
#
h = self.Dropout_0(h)
#
h = self.Conv_1(h)
#
if x.shape[1] != self.out_ch:
#
if self.conv_shortcut:
#
x = self.Conv_2(x)
#
else:
#
x = self.NIN_0(x)
#
if not self.skip_rescale:
#
return x + h
#
else:
#
return (x + h) / np.sqrt(2.0)
#
class ResnetBlockBigGANpp(nn.Module):
#
def __init__(
#
self,
#
act,
#
in_ch,
#
out_ch=None,
#
temb_dim=None,
#
up=False,
#
down=False,
#
dropout=0.1,
#
fir=False,
#
fir_kernel=(1, 3, 3, 1),
#
skip_rescale=True,
#
init_scale=0.0,
#
):
#
super().__init__()
#
#
out_ch = out_ch if out_ch else in_ch
#
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
#
self.up = up
#
self.down = down
#
self.fir = fir
#
self.fir_kernel = fir_kernel
#
#
self.Conv_0 = conv3x3(in_ch, out_ch)
#
if temb_dim is not None:
#
self.Dense_0 = nn.Linear(temb_dim, out_ch)
#
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
#
nn.init.zeros_(self.Dense_0.bias)
#
#
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
#
self.Dropout_0 = nn.Dropout(dropout)
#
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
#
if in_ch != out_ch or up or down:
#
self.Conv_2 = conv1x1(in_ch, out_ch)
#
#
self.skip_rescale = skip_rescale
#
self.act = act
#
self.in_ch = in_ch
#
self.out_ch = out_ch
#
#
def forward(self, x, temb=None):
#
h = self.act(self.GroupNorm_0(x))
#
#
if self.up:
#
if self.fir:
#
h = upsample_2d(h, self.fir_kernel, factor=2)
#
x = upsample_2d(x, self.fir_kernel, factor=2)
#
else:
#
h = naive_upsample_2d(h, factor=2)
#
x = naive_upsample_2d(x, factor=2)
#
elif self.down:
#
if self.fir:
#
h = downsample_2d(h, self.fir_kernel, factor=2)
#
x = downsample_2d(x, self.fir_kernel, factor=2)
#
else:
#
h = naive_downsample_2d(h, factor=2)
#
x = naive_downsample_2d(x, factor=2)
#
#
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
#
if temb is not None:
#
h += self.Dense_0(self.act(temb))[:, :, None, None]
#
h = self.act(self.GroupNorm_1(h))
#
h = self.Dropout_0(h)
#
h = self.Conv_1(h)
#
#
if self.in_ch != self.out_ch or self.up or self.down:
#
x = self.Conv_2(x)
#
#
if not self.skip_rescale:
#
return x + h
#
else:
#
return (x + h) / np.sqrt(2.0)
class
NCSNpp
(
ModelMixin
,
ConfigMixin
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment