OpenDAS / diffusers / Commits / 1899457b

Unverified commit 1899457b, authored Jun 29, 2022 by Patrick von Platen, committed by GitHub on Jun 29, 2022.

Merge pull request #40 from huggingface/start_resnet_unificiation

resnet in one file

Parents: e5d9baf0, ebf3717c
Changes: 7 changed files with 1123 additions and 564 deletions (+1123 −564).
src/diffusers/models/resnet.py                     +697 −122
src/diffusers/models/unet.py                        +41  −43
src/diffusers/models/unet_glide.py                 +102 −111
src/diffusers/models/unet_grad_tts.py               +21  −19
src/diffusers/models/unet_ldm.py                   +100 −104
src/diffusers/models/unet_rl.py                     +30  −30
src/diffusers/models/unet_sde_score_estimation.py  +132 −135
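Taken together, the hunks below drop the per-file residual-block definitions and pull them from the consolidated resnet module instead. A rough sketch of the combined import surface (each unet variant actually imports only the names it needs; the names are taken from the import hunks below and the module path from the file layout in this commit):

    # Names re-exported by src/diffusers/models/resnet.py after this merge,
    # as implied by the import hunks below.
    from diffusers.models.resnet import (
        Downsample,
        Upsample,
        ResnetBlock,            # used by unet.py
        ResBlock,               # used by unet_glide.py and unet_ldm.py
        TimestepBlock,          # abstract base for ResBlock
        ResnetBlockGradTTS,     # imported "as ResnetBlock" by unet_grad_tts.py
        ResidualTemporalBlock,  # used by unet_rl.py
        ResnetBlockDDPMpp,      # used by unet_sde_score_estimation.py
        ResnetBlockBigGANpp,    # used by unet_sde_score_estimation.py
    )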
src/diffusers/models/resnet.py  (view file @ 1899457b)

(This diff is collapsed on the page and not shown here.)
src/diffusers/models/unet.py  (view file @ 1899457b)

@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResnetBlock, Upsample


 def nonlinearity(x):
...

@@ -34,46 +34,46 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

The local ResnetBlock definition is removed and survives only as a line-by-line commented-out copy; the active implementation now lives in resnet.py and is imported above. Removed code:

class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


 class UNetModel(ModelMixin, ConfigMixin):
...

@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
         for i_level in range(self.num_resolutions):
             block = nn.ModuleList()
             attn = nn.ModuleList()
-            attn_2 = nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
             for i_block in range(self.num_res_blocks):
...

@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):
             down = nn.Module()
             down.block = block
             down.attn = attn
-            down.attn_2 = attn_2
             if i_level != self.num_resolutions - 1:
                 down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
...
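unet.py now consumes the consolidated block instead of its own copy. A minimal smoke test, assuming the implementation moved into resnet.py (whose diff is collapsed above) keeps the constructor and forward signature of the definition just removed:

    import torch
    from diffusers.models.resnet import ResnetBlock  # module path as of this commit

    # Keyword-only constructor, per the removed definition above.
    block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=512)

    x = torch.randn(2, 64, 32, 32)   # NCHW feature map
    temb = torch.randn(2, 512)       # timestep embedding
    out = block(x, temb)
    print(out.shape)                 # expected: torch.Size([2, 128, 32, 32])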
src/diffusers/models/unet_glide.py  (view file @ 1899457b)

@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResBlock, TimestepBlock, Upsample


 def convert_module_to_f16(l):
...

@@ -96,16 +96,14 @@ def zero_module(module):
     return module

The local TimestepBlock definition is removed and kept only as a commented-out copy; it is now imported from resnet.py. Removed code:

class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


 class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...

@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     return x

Likewise, the local ResBlock definition is removed and kept only as a commented-out copy. Removed code:

class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1
        convolution to change the channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels, swish=1.0),
            nn.Identity(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
        elif down:
            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


 class GlideUNetModel(ModelMixin, ConfigMixin):
...
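A minimal smoke test of the GLIDE-style block that unet_glide.py now imports, assuming the consolidated ResBlock in resnet.py keeps the signature of the definition removed above:

    import torch
    from diffusers.models.resnet import ResBlock  # module path as of this commit

    # Signature taken from the removed GLIDE ResBlock above.
    block = ResBlock(channels=64, emb_channels=256, dropout=0.0, out_channels=128, use_scale_shift_norm=True)

    x = torch.randn(2, 64, 32, 32)   # [N x C x H x W] feature map
    emb = torch.randn(2, 256)        # [N x emb_channels] timestep embedding
    out = block(x, emb)
    print(out.shape)                 # expected: torch.Size([2, 128, 32, 32])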
src/diffusers/models/unet_grad_tts.py  (view file @ 1899457b)

@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample
+from .resnet import ResnetBlockGradTTS as ResnetBlock
+from .resnet import Upsample


 class Mish(torch.nn.Module):
...

@@ -34,24 +36,24 @@ class Block(torch.nn.Module):
     return output * mask

The local ResnetBlock definition is removed and kept only as a commented-out copy; it is replaced by ResnetBlockGradTTS imported above. Removed code:

class ResnetBlock(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = torch.nn.Identity()

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


 class Residual(torch.nn.Module):
...
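A hedged smoke test of the Grad-TTS-style block, assuming the consolidated ResnetBlockGradTTS keeps the signature and masked forward of the definition removed above (x, mask, time_emb):

    import torch
    from diffusers.models.resnet import ResnetBlockGradTTS  # module path as of this commit

    block = ResnetBlockGradTTS(dim=64, dim_out=128, time_emb_dim=256, groups=8)

    x = torch.randn(2, 64, 80, 100)   # e.g. (batch, channels, n_feats, length)
    mask = torch.ones(2, 1, 1, 100)   # broadcastable sequence mask
    time_emb = torch.randn(2, 256)
    out = block(x, mask, time_emb)
    print(out.shape)                  # expected: torch.Size([2, 128, 80, 100])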
src/diffusers/models/unet_ldm.py  (view file @ 1899457b)

@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResBlock, TimestepBlock, Upsample


 def exists(val):
...

@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):
     return x[:, :, 0]

As in unet_glide.py, the local TimestepBlock definition (the abstract base whose forward() takes timestep embeddings as a second argument) is removed and kept only as a commented-out copy; it is now imported from resnet.py.

 class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...

@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     return x

The local ResBlock definition is removed as well, kept only as a commented-out copy headed "# class A_ResBlock(TimestepBlock):". The removed class is identical to the ResBlock removed from unet_glide.py above, except for its input and output stacks:

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        ...
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )


 class QKVAttention(nn.Module):
...
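Both the GLIDE- and LDM-style ResBlocks condition on the timestep embedding either additively or, with use_scale_shift_norm=True, via a scale-shift modulation of the normalized features. A self-contained sketch of that scale-shift branch, using plain PyTorch modules in place of the normalization()/conv_nd() helpers referenced above:

    import torch
    import torch.nn as nn

    # Illustrative only: the scale-shift branch of ResBlock.forward,
    # written with a plain GroupNorm instead of the library's normalization() helper.
    N, C, H, W = 2, 64, 16, 16
    h = torch.randn(N, C, H, W)
    emb_out = torch.randn(N, 2 * C)           # emb_layers output: 2 * out_channels when use_scale_shift_norm=True

    while len(emb_out.shape) < len(h.shape):  # broadcast the embedding over the spatial dims
        emb_out = emb_out[..., None]          # -> [N, 2C, 1, 1]

    out_norm = nn.GroupNorm(32, C)
    scale, shift = torch.chunk(emb_out, 2, dim=1)
    h = out_norm(h) * (1 + scale) + shift     # normalize first, then modulate per channel
    print(h.shape)                            # torch.Size([2, 64, 16, 16])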
src/diffusers/models/unet_rl.py  (view file @ 1899457b)

@@ -6,6 +6,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import ResidualTemporalBlock


 class SinusoidalPosEmb(nn.Module):
...

@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):
     return self.block(x)

The local ResidualTemporalBlock definition is removed and kept only as a commented-out copy; it is now imported from resnet.py. Removed code:

class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(inp_channels, out_channels, kernel_size),
                Conv1dBlock(out_channels, out_channels, kernel_size),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            RearrangeDim(),
            # Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]
        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)


 class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
...
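A shape-level smoke test of the 1D temporal block that unet_rl.py now imports, assuming the consolidated ResidualTemporalBlock keeps the signature and shapes documented in the removed code above:

    import torch
    from diffusers.models.resnet import ResidualTemporalBlock  # module path as of this commit

    block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=128, horizon=16)

    x = torch.randn(4, 14, 16)   # [batch_size x inp_channels x horizon]
    t = torch.randn(4, 128)      # [batch_size x embed_dim]
    out = block(x, t)
    print(out.shape)             # expected: torch.Size([4, 32, 16])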
src/diffusers/models/unet_sde_score_estimation.py  (view file @ 1899457b)

@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
+from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp


 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
...

@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))


-def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
+def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
     """1x1 convolution with DDPM initialization."""
     conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
     conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
...

@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
     return conv


-def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
+def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
     """3x3 convolution with DDPM initialization."""
     conv = nn.Conv2d(
         in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
...

@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
     return conv


-conv1x1 = ddpm_conv1x1
-conv3x3 = ddpm_conv3x3


 def _einsum(a, b, c, x, y):
     einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
     return torch.einsum(einsum_str, x, y)
...

@@ -494,135 +491,135 @@ class Downsample(nn.Module):
     return x

The local ResnetBlockDDPMpp definition is removed and kept only as a commented-out copy; it is now imported from resnet.py. Removed code:

class ResnetBlockDDPMpp(nn.Module):
    """ResBlock adapted from DDPM."""

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        conv_shortcut=False,
        dropout=0.1,
        skip_rescale=False,
        init_scale=0.0,
    ):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))
        h = self.Conv_0(h)
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)
        if x.shape[1] != self.out_ch:
            if self.conv_shortcut:
                x = self.Conv_2(x)
            else:
                x = self.NIN_0(x)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)
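This block (like the BigGAN-style block that follows) exposes a skip_rescale option: when enabled, the residual sum is divided by sqrt(2), which keeps the output on roughly the same scale as the inputs when x and h are roughly independent and similarly scaled. A quick numerical check:

    import torch

    x = torch.randn(1_000_000)
    h = torch.randn(1_000_000)

    print((x + h).std())              # ~1.41: plain addition grows the scale by sqrt(2)
    print(((x + h) / 2 ** 0.5).std()) # ~1.0: the skip_rescale branch keeps it at unit scale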
Similarly, the local ResnetBlockBigGANpp definition is removed and kept only as a commented-out copy; it is now imported from resnet.py. Removed code:

class ResnetBlockBigGANpp(nn.Module):
    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        up=False,
        down=False,
        dropout=0.1,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    ):
        super().__init__()

        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
            if self.fir:
                h = upsample_2d(h, self.fir_kernel, factor=2)
                x = upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_upsample_2d(h, factor=2)
                x = naive_upsample_2d(x, factor=2)
        elif self.down:
            if self.fir:
                h = downsample_2d(h, self.fir_kernel, factor=2)
                x = downsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = naive_downsample_2d(h, factor=2)
                x = naive_downsample_2d(x, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)


 class NCSNpp(ModelMixin, ConfigMixin):
...
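The BigGAN-style block switches between FIR resampling (upfirdn2d with fir_kernel) and "naive" resampling helpers. As a rough, self-contained illustration of what factor-2 naive resampling conventionally does (nearest-neighbor upsampling and box-average downsampling; this is not the library's implementation):

    import torch

    def naive_upsample_2d_demo(x, factor=2):
        # Nearest-neighbor upsampling via reshape/repeat (illustrative only).
        n, c, h, w = x.shape
        x = x.reshape(n, c, h, 1, w, 1).repeat(1, 1, 1, factor, 1, factor)
        return x.reshape(n, c, h * factor, w * factor)

    def naive_downsample_2d_demo(x, factor=2):
        # Box-average downsampling via reshape/mean (illustrative only).
        n, c, h, w = x.shape
        x = x.reshape(n, c, h // factor, factor, w // factor, factor)
        return x.mean(dim=(3, 5))

    x = torch.randn(1, 8, 16, 16)
    print(naive_upsample_2d_demo(x).shape)    # torch.Size([1, 8, 32, 32])
    print(naive_downsample_2d_demo(x).shape)  # torch.Size([1, 8, 8, 8])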