renzhc / diffusers_dcu · Commits

Commit ebf3717c
Authored Jun 29, 2022 by Patrick von Platen
Parent: e5d9baf0

    resnet in one file
Showing 7 changed files with 1123 additions and 564 deletions.
src/diffusers/models/resnet.py                      +697  -122
src/diffusers/models/unet.py                         +41   -43
src/diffusers/models/unet_glide.py                  +102  -111
src/diffusers/models/unet_grad_tts.py                +21   -19
src/diffusers/models/unet_ldm.py                    +100  -104
src/diffusers/models/unet_rl.py                      +30   -30
src/diffusers/models/unet_sde_score_estimation.py   +132  -135
src/diffusers/models/resnet.py

(This diff is collapsed in the page view; its contents are not shown here.)
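Since the resnet.py diff itself is not expanded, the only thing that can be read off this page is the set of names the consolidated module now has to provide, collected from the import lines in the diffs below. Grouping them into a single import is purely illustrative; the actual layout of resnet.py at this commit may differ, and the import path assumes the package is importable as diffusers.

# Names that src/diffusers/models/resnet.py must export at this commit,
# collected from the import statements in the other changed files below.
from diffusers.models.resnet import (
    Downsample,
    ResBlock,
    ResidualTemporalBlock,
    ResnetBlock,
    ResnetBlockBigGANpp,
    ResnetBlockDDPMpp,
    ResnetBlockGradTTS,
    TimestepBlock,
    Upsample,
)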
src/diffusers/models/unet.py

@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResnetBlock, Upsample


 def nonlinearity(x):
...
@@ -34,46 +34,46 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


-class ResnetBlock(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
-        super().__init__()
-        self.in_channels = in_channels
-        out_channels = in_channels if out_channels is None else out_channels
-        self.out_channels = out_channels
-        self.use_conv_shortcut = conv_shortcut
-
-        self.norm1 = Normalize(in_channels)
-        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-        self.norm2 = Normalize(out_channels)
-        self.dropout = torch.nn.Dropout(dropout)
-        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x, temb):
-        h = x
-        h = self.norm1(h)
-        h = nonlinearity(h)
-        h = self.conv1(h)
-
-        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-
-        h = self.norm2(h)
-        h = nonlinearity(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
-
-        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
-            else:
-                x = self.nin_shortcut(x)
-
-        return x + h
+ (added in its place: the same ResnetBlock definition, commented out line by line)


 class UNetModel(ModelMixin, ConfigMixin):
...
@@ -127,7 +127,6 @@ class UNetModel(ModelMixin, ConfigMixin):
         for i_level in range(self.num_resolutions):
             block = nn.ModuleList()
             attn = nn.ModuleList()
-            attn_2 = nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
             for i_block in range(self.num_res_blocks):
...
@@ -142,7 +141,6 @@ class UNetModel(ModelMixin, ConfigMixin):
             down = nn.Module()
             down.block = block
             down.attn = attn
-            down.attn_2 = attn_2
             if i_level != self.num_resolutions - 1:
                 down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
...
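For orientation, here is a minimal sketch of how the ResnetBlock that unet.py now imports from .resnet is driven. The keyword-only constructor and the forward(x, temb) signature are taken from the class removed above; the channel sizes and tensor shapes are made-up illustration values, and the import path assumes the package is importable as diffusers at this commit.

import torch
from diffusers.models.resnet import ResnetBlock  # single home of the block after this commit

# Illustrative sizes only; the signature mirrors the class removed from unet.py above.
block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.1, temb_channels=512)

x = torch.randn(2, 64, 32, 32)   # feature map [N, C, H, W]
temb = torch.randn(2, 512)       # timestep embedding [N, temb_channels]
out = block(x, temb)             # -> [2, 128, 32, 32]: conv path plus 1x1 shortcut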
src/diffusers/models/unet_glide.py

@@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResBlock, TimestepBlock, Upsample


 def convert_module_to_f16(l):
...
@@ -96,16 +96,14 @@ def zero_module(module):
    return module


-class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
-
-    @abstractmethod
-    def forward(self, x, emb):
-        """
-        Apply the module to `x` given `emb` timestep embeddings.
-        """
+ (added in its place: the same TimestepBlock definition, commented out line by line)


 class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -124,106 +122,99 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x


-class ResBlock(TimestepBlock):
-    """
-    A residual block that can optionally change the number of channels.
-
-    :param channels: the number of input channels.
-    :param emb_channels: the number of timestep embedding channels.
-    :param dropout: the rate of dropout.
-    :param out_channels: if specified, the number of out channels.
-    :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1
-        convolution to change the channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D.
-    :param use_checkpoint: if True, use gradient checkpointing on this module.
-    :param up: if True, use this block for upsampling.
-    :param down: if True, use this block for downsampling.
-    """
-
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        use_conv=False,
-        use_scale_shift_norm=False,
-        dims=2,
-        use_checkpoint=False,
-        up=False,
-        down=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
-        self.use_scale_shift_norm = use_scale_shift_norm
-
-        self.in_layers = nn.Sequential(
-            normalization(channels, swish=1.0),
-            nn.Identity(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
-        )
-
-        self.updown = up or down
-
-        if up:
-            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
-            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
-        elif down:
-            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-        else:
-            self.h_upd = self.x_upd = nn.Identity()
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
-            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
-            nn.Dropout(p=dropout),
-            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        elif use_conv:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-    def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-
-        :param x: an [N x C x ...] Tensor of features.
-        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        if self.updown:
-            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
-            h = self.h_upd(h)
-            x = self.x_upd(x)
-            h = in_conv(h)
-        else:
-            h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-        if self.use_scale_shift_norm:
-            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
-        else:
-            h = h + emb_out
-            h = self.out_layers(h)
-        return self.skip_connection(x) + h
+ (added in its place: the same ResBlock definition, commented out line by line)


 class GlideUNetModel(ModelMixin, ConfigMixin):
...
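The heart of the ResBlock that unet_glide.py now imports is the scale-shift conditioning visible in its forward method above. The snippet below re-runs just that branch in isolation with plain torch modules; GroupNorm stands in for the normalization() helper used in the diff, and all tensor sizes are illustrative.

import torch

N, C, H, W = 2, 128, 16, 16
h = torch.randn(N, C, H, W)
emb_out = torch.randn(N, 2 * C, 1, 1)          # embedding projected to 2 * out_channels

norm = torch.nn.GroupNorm(32, C)               # stand-in for normalization(self.out_channels, ...)
scale, shift = torch.chunk(emb_out, 2, dim=1)  # each [N, C, 1, 1]
h = norm(h) * (1 + scale) + shift              # per-channel modulation, as in the diff above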
src/diffusers/models/unet_grad_tts.py

@@ -4,7 +4,9 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample
+from .resnet import ResnetBlockGradTTS as ResnetBlock
+from .resnet import Upsample


 class Mish(torch.nn.Module):
...
@@ -34,24 +36,24 @@ class Block(torch.nn.Module):
        return output * mask


-class ResnetBlock(torch.nn.Module):
-    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
-        super(ResnetBlock, self).__init__()
-        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        if dim != dim_out:
-            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
-        else:
-            self.res_conv = torch.nn.Identity()
-
-    def forward(self, x, mask, time_emb):
-        h = self.block1(x, mask)
-        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
-        h = self.block2(h, mask)
-        output = h + self.res_conv(x * mask)
-        return output
+ (added in its place: the same ResnetBlock definition, commented out line by line)


 class Residual(torch.nn.Module):
...
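Because unet_grad_tts.py imports the moved class under its old name (ResnetBlockGradTTS as ResnetBlock), its call sites stay unchanged. Below is a small usage sketch, assuming the moved class keeps the __init__(dim, dim_out, time_emb_dim, groups=8) and forward(x, mask, time_emb) signatures shown above; the spectrogram-like shapes and the mask layout are assumptions made purely for illustration.

import torch
from diffusers.models.resnet import ResnetBlockGradTTS as ResnetBlock  # same alias as in the diff

block = ResnetBlock(dim=64, dim_out=128, time_emb_dim=256)

x = torch.randn(2, 64, 80, 172)    # hypothetical feature map
mask = torch.ones(2, 1, 80, 172)   # hypothetical binary mask, broadcast over channels
time_emb = torch.randn(2, 256)
out = block(x, mask, time_emb)     # -> [2, 128, 80, 172]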
src/diffusers/models/unet_ldm.py

@@ -11,7 +11,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, Upsample
+from .resnet import Downsample, ResBlock, TimestepBlock, Upsample


 def exists(val):
...
@@ -359,16 +359,14 @@ class AttentionPool2d(nn.Module):
        return x[:, :, 0]


-class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
-
-    @abstractmethod
-    def forward(self, x, emb):
-        """
-        Apply the module to `x` given `emb` timestep embeddings.
-        """
+ (added in its place: the same TimestepBlock definition, commented out line by line)


 class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -387,99 +385,97 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x


-class ResBlock(TimestepBlock):
-    """
-    A residual block that can optionally change the number of channels.
-
-    :param channels: the number of input channels.
-    :param emb_channels: the number of timestep embedding channels.
-    :param dropout: the rate of dropout.
-    :param out_channels: if specified, the number of out channels.
-    :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1
-        convolution to change the channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D.
-    :param use_checkpoint: if True, use gradient checkpointing on this module.
-    :param up: if True, use this block for upsampling.
-    :param down: if True, use this block for downsampling.
-    """
-
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        use_conv=False,
-        use_scale_shift_norm=False,
-        dims=2,
-        use_checkpoint=False,
-        up=False,
-        down=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
-        self.use_scale_shift_norm = use_scale_shift_norm
-
-        self.in_layers = nn.Sequential(
-            normalization(channels),
-            nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
-        )
-
-        self.updown = up or down
-
-        if up:
-            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
-            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
-        elif down:
-            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-        else:
-            self.h_upd = self.x_upd = nn.Identity()
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
-            nn.SiLU(),
-            nn.Dropout(p=dropout),
-            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        elif use_conv:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-    def forward(self, x, emb):
-        if self.updown:
-            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
-            h = self.h_upd(h)
-            x = self.x_upd(x)
-            h = in_conv(h)
-        else:
-            h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-        if self.use_scale_shift_norm:
-            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
-        else:
-            h = h + emb_out
-            h = self.out_layers(h)
-        return self.skip_connection(x) + h
+ (added in its place: the same ResBlock definition, commented out line by line and renamed A_ResBlock in the comments)


 class QKVAttention(nn.Module):
...
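One detail worth noting in the ResBlock kept (as comments) above is how the timestep embedding is broadcast: trailing singleton dimensions are appended until it can be added to the feature map, which is what lets the same block serve dims=1, 2, or 3. A standalone rerun of that loop, with made-up sizes:

import torch

h = torch.randn(2, 128, 8, 16, 16)   # e.g. a 3D feature map [N, C, D, H, W]
emb_out = torch.randn(2, 128)        # [N, C] projection of the timestep embedding

while len(emb_out.shape) < len(h.shape):
    emb_out = emb_out[..., None]     # -> [N, C, 1, 1, 1]

h = h + emb_out                      # additive conditioning (the non scale-shift path)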
src/diffusers/models/unet_rl.py

@@ -6,6 +6,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import ResidualTemporalBlock


 class SinusoidalPosEmb(nn.Module):
...
@@ -72,36 +73,35 @@ class Conv1dBlock(nn.Module):
        return self.block(x)


-class ResidualTemporalBlock(nn.Module):
-    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
-        super().__init__()
-
-        self.blocks = nn.ModuleList(
-            [
-                Conv1dBlock(inp_channels, out_channels, kernel_size),
-                Conv1dBlock(out_channels, out_channels, kernel_size),
-            ]
-        )
-
-        self.time_mlp = nn.Sequential(
-            nn.Mish(),
-            nn.Linear(embed_dim, out_channels),
-            RearrangeDim(),
-            # Rearrange("batch t -> batch t 1"),
-        )
-
-        self.residual_conv = (
-            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
-        )
-
-    def forward(self, x, t):
-        """
-        x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
-        out_channels x horizon ]
-        """
-        out = self.blocks[0](x) + self.time_mlp(t)
-        out = self.blocks[1](out)
-        return out + self.residual_conv(x)
+ (added in its place: the same ResidualTemporalBlock definition, commented out line by line)


 class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
...
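A usage sketch for the ResidualTemporalBlock that unet_rl.py now imports, following the shapes documented in its forward docstring above; the concrete batch size, channel counts, and horizon are invented, and the constructor signature is assumed to match the removed class.

import torch
from diffusers.models.resnet import ResidualTemporalBlock  # as imported in the diff

block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=128, horizon=64)

x = torch.randn(4, 14, 64)   # [batch_size x inp_channels x horizon]
t = torch.randn(4, 128)      # [batch_size x embed_dim]
out = block(x, t)            # -> [4, 32, 64], i.e. [batch_size x out_channels x horizon]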
src/diffusers/models/unet_sde_score_estimation.py

@@ -28,6 +28,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
+from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp


 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
...
@@ -299,7 +300,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))


-def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
+def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
...
@@ -307,7 +308,7 @@ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, pad
    return conv


-def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
+def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
...
@@ -317,10 +318,6 @@ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_sc
    return conv


-conv1x1 = ddpm_conv1x1
-conv3x3 = ddpm_conv3x3
-
-
 def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)
...
@@ -494,135 +491,135 @@ class Downsample(nn.Module):
        return x


-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
-class ResnetBlockBigGANpp(nn.Module):
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        up=False,
-        down=False,
-        dropout=0.1,
-        fir=False,
-        fir_kernel=(1, 3, 3, 1),
-        skip_rescale=True,
-        init_scale=0.0,
-    ):
-        super().__init__()
-
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.up = up
-        self.down = down
-        self.fir = fir
-        self.fir_kernel = fir_kernel
-
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.in_ch = in_ch
-        self.out_ch = out_ch
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-
-        if self.up:
-            if self.fir:
-                h = upsample_2d(h, self.fir_kernel, factor=2)
-                x = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_upsample_2d(h, factor=2)
-                x = naive_upsample_2d(x, factor=2)
-        elif self.down:
-            if self.fir:
-                h = downsample_2d(h, self.fir_kernel, factor=2)
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_downsample_2d(h, factor=2)
-                x = naive_downsample_2d(x, factor=2)
-
-        h = self.Conv_0(h)
-        # Add bias to each feature map conditioned on the time embedding
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-
-        if self.in_ch != self.out_ch or self.up or self.down:
-            x = self.Conv_2(x)
-
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
+ (added in their place: the same ResnetBlockDDPMpp and ResnetBlockBigGANpp definitions, commented out line by line)


 class NCSNpp(ModelMixin, ConfigMixin):
...
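Both score-SDE blocks above end with the same skip_rescale branch: when enabled, the residual sum is divided by sqrt(2). The quick numerical check below shows why, assuming the two branches are roughly independent with unit variance: the raw sum has variance about 2, and the rescale brings it back to about 1 so activations do not grow with depth.

import numpy as np
import torch

x = torch.randn(100_000)
h = torch.randn(100_000)

print((x + h).var())                   # ~2.0
print(((x + h) / np.sqrt(2.0)).var())  # ~1.0, matching the skip_rescale branch above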