renzhc / diffusers_dcu

Commit efe1e60e, authored Jun 30, 2022 by Patrick von Platen
Parent: fd6f93b2

    merge glide into resnets

Showing 3 changed files, with 16 additions and 302 deletions:

    src/diffusers/models/resnet.py       +6   -237
    src/diffusers/models/unet_glide.py   +9   -64
    tests/test_modeling_utils.py         +1   -1
src/diffusers/models/resnet.py

@@ -161,229 +161,7 @@ class Downsample(nn.Module):
 # RESNETS

-# unet_glide.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py


-class ResBlock(TimestepBlock):
-    """
-    A residual block that can optionally change the number of channels.
-
-    :param channels: the number of input channels.
-    :param emb_channels: the number of timestep embedding channels.
-    :param dropout: the rate of dropout.
-    :param out_channels: if specified, the number of out channels.
-    :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1
-        convolution to change the channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D.
-    :param use_checkpoint: if True, use gradient checkpointing on this module.
-    :param up: if True, use this block for upsampling.
-    :param down: if True, use this block for downsampling.
-    """
-
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        use_conv=False,
-        use_scale_shift_norm=False,
-        dims=2,
-        use_checkpoint=False,
-        up=False,
-        down=False,
-        overwrite=True,  # TODO(Patrick) - use for glide at later stage
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
-        self.use_scale_shift_norm = use_scale_shift_norm
-
-        self.in_layers = nn.Sequential(
-            normalization(channels, swish=1.0),
-            nn.Identity(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
-        )
-
-        self.updown = up or down
-
-        if up:
-            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
-            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
-        elif down:
-            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-        else:
-            self.h_upd = self.x_upd = nn.Identity()
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                2 * self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels, swish=0.0),
-            nn.SiLU(),
-            nn.Dropout(p=dropout),
-            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        elif use_conv:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-        self.overwrite = overwrite
-        self.is_overwritten = False
-        if self.overwrite:
-            in_channels = channels
-            out_channels = self.out_channels
-            conv_shortcut = False
-            dropout = 0.0
-            temb_channels = emb_channels
-            groups = 32
-            pre_norm = True
-            eps = 1e-5
-            non_linearity = "silu"
-
-            self.pre_norm = pre_norm
-            self.in_channels = in_channels
-            out_channels = in_channels if out_channels is None else out_channels
-            self.out_channels = out_channels
-            self.use_conv_shortcut = conv_shortcut
-            # Add to init
-            self.time_embedding_norm = "scale_shift"
-
-            if self.pre_norm:
-                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
-            else:
-                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
-
-            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
-            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
-            self.dropout = torch.nn.Dropout(dropout)
-            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-
-            if non_linearity == "swish":
-                self.nonlinearity = nonlinearity
-            elif non_linearity == "mish":
-                self.nonlinearity = Mish()
-            elif non_linearity == "silu":
-                self.nonlinearity = nn.SiLU()
-
-            if self.in_channels != self.out_channels:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
-            self.up, self.down = up, down
-            # if self.up:
-            #     self.h_upd = Upsample(in_channels, use_conv=False, dims=dims)
-            #     self.x_upd = Upsample(in_channels, use_conv=False, dims=dims)
-            # elif self.down:
-            #     self.h_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
-            #     self.x_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
-
-    def set_weights(self):
-        # TODO(Patrick): use for glide at later stage
-        self.norm1.weight.data = self.in_layers[0].weight.data
-        self.norm1.bias.data = self.in_layers[0].bias.data
-        self.conv1.weight.data = self.in_layers[-1].weight.data
-        self.conv1.bias.data = self.in_layers[-1].bias.data
-
-        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
-        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
-
-        self.norm2.weight.data = self.out_layers[0].weight.data
-        self.norm2.bias.data = self.out_layers[0].bias.data
-        self.conv2.weight.data = self.out_layers[-1].weight.data
-        self.conv2.bias.data = self.out_layers[-1].bias.data
-
-        if self.in_channels != self.out_channels:
-            self.nin_shortcut.weight.data = self.skip_connection.weight.data
-            self.nin_shortcut.bias.data = self.skip_connection.bias.data
-
-    def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-
-        :param x: an [N x C x ...] Tensor of features.
-        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        if self.overwrite:
-            # TODO(Patrick): use for glide at later stage
-            self.set_weights()
-
-        orig_x = x
-
-        if self.updown:
-            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
-            h = self.h_upd(h)
-            x = self.x_upd(x)
-            h = in_conv(h)
-        else:
-            h = self.in_layers(x)
-
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-
-        if self.use_scale_shift_norm:
-            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
-        else:
-            h = h + emb_out
-            h = self.out_layers(h)
-
-        result = self.skip_connection(x) + h
-
-        # TODO(Patrick) Use for glide at later stage
-        result = self.forward_2(orig_x, emb)
-        return result
-
-    def forward_2(self, x, temb):
-        if self.overwrite and not self.is_overwritten:
-            self.set_weights()
-            self.is_overwritten = True
-
-        h = x
-        h = self.norm1(h)
-        h = self.nonlinearity(h)
-
-        if self.up or self.down:
-            x = self.x_upd(x)
-            h = self.h_upd(h)
-
-        h = self.conv1(h)
-
-        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-
-        if self.time_embedding_norm == "scale_shift":
-            scale, shift = torch.chunk(temb, 2, dim=1)
-
-            h = self.norm2(h)
-            h = h + h * scale + shift
-            h = self.nonlinearity(h)
-        else:
-            h = h + temb
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
-
-        h = self.dropout(h)
-        h = self.conv2(h)
-
-        if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
-
-        return x + h
-
-
-# unet.py, unet_grad_tts.py, unet_ldm.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
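The entire GLIDE-specific ResBlock above is deleted; GLIDE is routed through the shared ResnetBlock instead, configured as in the unet_glide.py hunks further down. A minimal construction sketch of that mapping, assuming only the keyword arguments this diff itself passes (the variable values are illustrative placeholders, not from the commit):

    import torch
    from diffusers.models.resnet import ResnetBlock

    # Hypothetical stand-ins for the loop variables used in unet_glide.py.
    ch, model_channels, mult = 64, 64, 2
    time_embed_dim, dropout, use_scale_shift_norm = 256, 0.0, True

    # Old: ResBlock(ch, time_embed_dim, dropout, out_channels=int(mult * model_channels),
    #               dims=dims, use_checkpoint=use_checkpoint,
    #               use_scale_shift_norm=use_scale_shift_norm)
    # New equivalent, as constructed throughout unet_glide.py in this commit:
    block = ResnetBlock(
        in_channels=ch,
        out_channels=mult * model_channels,
        dropout=dropout,
        temb_channels=time_embed_dim,
        eps=1e-5,
        non_linearity="silu",
        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
        overwrite_for_glide=True,
    )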
@@ -445,12 +223,9 @@ class ResnetBlock(nn.Module):
             self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

         if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                # TODO(Patrick) - this branch is never used I think => can be deleted!
-                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

+        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
         self.is_overwritten = False
         self.overwrite_for_glide = overwrite_for_glide
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
@@ -497,8 +272,6 @@ class ResnetBlock(nn.Module):
         )

         if self.out_channels == in_channels:
             self.skip_connection = nn.Identity()
-        # elif use_conv:
-        #     self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
         else:
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@@ -541,6 +314,8 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.bias.data = self.skip_connection.bias.data

     def forward(self, x, temb, mask=1.0):
+        # TODO(Patrick) eventually this class should be split into multiple classes
+        # too many if else statements
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
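set_weights_grad_tts() here, like the set_weights() on the deleted ResBlock above, ports parameters between two module layouts by assigning .data in place. A self-contained sketch of that pattern, using generic layers rather than this file's classes:

    import torch
    import torch.nn as nn

    # Checkpoint-era layout: a Sequential with an nn.Identity placeholder,
    # mirroring the deleted ResBlock's in_layers.
    old = nn.Sequential(nn.GroupNorm(2, 4), nn.Identity(), nn.Conv2d(4, 4, 3, padding=1))
    # Target layout: named attributes, mirroring ResnetBlock's norm1 / conv1.
    norm1, conv1 = nn.GroupNorm(2, 4), nn.Conv2d(4, 4, 3, padding=1)

    # Copy parameters in place, exactly the .data assignment pattern above.
    norm1.weight.data = old[0].weight.data
    norm1.bias.data = old[0].bias.data
    conv1.weight.data = old[-1].weight.data
    conv1.bias.data = old[-1].bias.data

    x = torch.randn(1, 4, 8, 8)
    assert torch.allclose(conv1(norm1(x)), old(x))  # both layouts now agree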
@@ -566,6 +341,7 @@ class ResnetBlock(nn.Module):
         h = h * mask

         temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
         if self.time_embedding_norm == "scale_shift":
             scale, shift = torch.chunk(temb, 2, dim=1)
@@ -589,9 +365,6 @@ class ResnetBlock(nn.Module):
         x = x * mask

         if self.in_channels != self.out_channels:
-            # if self.use_conv_shortcut:
-            #     x = self.conv_shortcut(x)
-            # else:
             x = self.nin_shortcut(x)

         return x + h
@@ -605,10 +378,6 @@ class Block(torch.nn.Module):
             torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
         )

-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask
-

 # unet_score_estimation.py
 class ResnetBlockBigGANpp(nn.Module):
src/diffusers/models/unet_glide.py

@@ -6,8 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
-from .resnet import ResnetBlock
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample


 def convert_module_to_f16(l):
@@ -191,15 +190,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    # ResBlock(
-                    #     ch,
-                    #     time_embed_dim,
-                    #     dropout,
-                    #     out_channels=int(mult * model_channels),
-                    #     dims=dims,
-                    #     use_checkpoint=use_checkpoint,
-                    #     use_scale_shift_norm=use_scale_shift_norm,
-                    # )
                     ResnetBlock(
                         in_channels=ch,
                         out_channels=mult * model_channels,
@@ -207,7 +197,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                         temb_channels=time_embed_dim,
                         eps=1e-5,
                         non_linearity="silu",
-                        time_embedding_norm="scale_shift",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                         overwrite_for_glide=True,
                     )
                 ]
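The one-line change repeated throughout this file makes the conditioning mode follow use_scale_shift_norm instead of hard-coding "scale_shift". A rough sketch of what the two modes compute (normalization and activation omitted; the "scale_shift" arithmetic matches the h + h * scale + shift seen in forward_2 in the resnet.py diff above, and the "default" projection width is an assumption since it lies outside this diff):

    import torch

    N, C, H, W = 2, 8, 4, 4
    h = torch.randn(N, C, H, W)

    # "scale_shift": the time projection yields 2*C channels, split into scale/shift.
    temb = torch.randn(N, 2 * C)[:, :, None, None]
    scale, shift = torch.chunk(temb, 2, dim=1)
    h_scale_shift = h + h * scale + shift

    # "default": a C-channel projection is simply added to the hidden state.
    temb_add = torch.randn(N, C)[:, :, None, None]
    h_default = h + temb_add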
@@ -229,16 +219,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        # ResBlock(
-                        #     ch,
-                        #     time_embed_dim,
-                        #     dropout,
-                        #     out_channels=out_ch,
-                        #     dims=dims,
-                        #     use_checkpoint=use_checkpoint,
-                        #     use_scale_shift_norm=use_scale_shift_norm,
-                        #     down=True,
-                        # )
                         ResnetBlock(
                             in_channels=ch,
                             out_channels=out_ch,
@@ -246,9 +226,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             temb_channels=time_embed_dim,
                             eps=1e-5,
                             non_linearity="silu",
-                            time_embedding_norm="scale_shift",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                             overwrite_for_glide=True,
-                            down=True
+                            down=True,
                         )
                         if resblock_updown
                         else Downsample(
@@ -262,21 +242,13 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             self._feature_size += ch

         self.middle_block = TimestepEmbedSequential(
-            # ResBlock(
-            #     ch,
-            #     time_embed_dim,
-            #     dropout,
-            #     dims=dims,
-            #     use_checkpoint=use_checkpoint,
-            #     use_scale_shift_norm=use_scale_shift_norm,
-            # ),
             ResnetBlock(
                 in_channels=ch,
                 dropout=dropout,
                 temb_channels=time_embed_dim,
                 eps=1e-5,
                 non_linearity="silu",
-                time_embedding_norm="scale_shift",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                 overwrite_for_glide=True,
             ),
             AttentionBlock(
@@ -286,23 +258,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
             ),
-            # ResBlock(
-            #     ch,
-            #     time_embed_dim,
-            #     dropout,
-            #     dims=dims,
-            #     use_checkpoint=use_checkpoint,
-            #     use_scale_shift_norm=use_scale_shift_norm,
-            # ),
             ResnetBlock(
                 in_channels=ch,
                 dropout=dropout,
                 temb_channels=time_embed_dim,
                 eps=1e-5,
                 non_linearity="silu",
-                time_embedding_norm="scale_shift",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                 overwrite_for_glide=True,
-            )
+            ),
         )
         self._feature_size += ch
@@ -311,15 +275,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    # ResBlock(
-                    #     ch + ich,
-                    #     time_embed_dim,
-                    #     dropout,
-                    #     out_channels=int(model_channels * mult),
-                    #     dims=dims,
-                    #     use_checkpoint=use_checkpoint,
-                    #     use_scale_shift_norm=use_scale_shift_norm,
-                    # )
                     ResnetBlock(
                         in_channels=ch + ich,
                         out_channels=model_channels * mult,
@@ -327,7 +282,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                         temb_channels=time_embed_dim,
                         eps=1e-5,
                         non_linearity="silu",
-                        time_embedding_norm="scale_shift",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                         overwrite_for_glide=True,
                     ),
                 ]
@@ -345,16 +300,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        # ResBlock(
-                        #     ch,
-                        #     time_embed_dim,
-                        #     dropout,
-                        #     out_channels=out_ch,
-                        #     dims=dims,
-                        #     use_checkpoint=use_checkpoint,
-                        #     use_scale_shift_norm=use_scale_shift_norm,
-                        #     up=True,
-                        # )
                         ResnetBlock(
                             in_channels=ch,
                             out_channels=out_ch,
@@ -362,7 +307,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             temb_channels=time_embed_dim,
                             eps=1e-5,
                             non_linearity="silu",
-                            time_embedding_norm="scale_shift",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                             overwrite_for_glide=True,
                             up=True,
                         )
tests/test_modeling_utils.py

@@ -795,7 +795,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         sizes = (32, 32)

         noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [9.]).to(torch_device)
+        time_step = torch.tensor(batch_size * [9.0]).to(torch_device)

         with torch.no_grad():
             output = model(noise, time_step)
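The test change is only a float-literal style fix (9. becomes 9.0). For completeness, a hypothetical shape-check in the same style that would exercise the ResnetBlock now backing the GLIDE UNet; the argument list mirrors the unet_glide.py calls above, but whether these defaults suffice depends on parts of resnet.py outside this diff:

    import torch
    from diffusers.models.resnet import ResnetBlock

    batch_size, num_channels, sizes = 4, 32, (32, 32)
    temb_channels = 128  # hypothetical embedding width

    block = ResnetBlock(
        in_channels=num_channels,
        out_channels=num_channels,
        dropout=0.0,
        temb_channels=temb_channels,
        eps=1e-5,
        non_linearity="silu",
    )
    noise = torch.randn((batch_size, num_channels) + sizes)
    time_emb = torch.randn(batch_size, temb_channels)
    with torch.no_grad():
        output = block(noise, time_emb)
    assert output.shape == noise.shape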