Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
efe1e60e
Commit
efe1e60e
authored
Jun 30, 2022
by
Patrick von Platen
Browse files
merge glide into resnets
parent
fd6f93b2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
302 deletions
+16
-302
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+6
-237
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+9
-64
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+1
-1
No files found.
src/diffusers/models/resnet.py
View file @
efe1e60e
...
@@ -161,229 +161,7 @@ class Downsample(nn.Module):
...
@@ -161,229 +161,7 @@ class Downsample(nn.Module):
# RESNETS
# RESNETS
# unet_glide.py
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class
ResBlock
(
TimestepBlock
):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
downsampling.
"""
def
__init__
(
self
,
channels
,
emb_channels
,
dropout
,
out_channels
=
None
,
use_conv
=
False
,
use_scale_shift_norm
=
False
,
dims
=
2
,
use_checkpoint
=
False
,
up
=
False
,
down
=
False
,
overwrite
=
True
,
# TODO(Patrick) - use for glide at later stage
):
super
().
__init__
()
self
.
channels
=
channels
self
.
emb_channels
=
emb_channels
self
.
dropout
=
dropout
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
use_checkpoint
=
use_checkpoint
self
.
use_scale_shift_norm
=
use_scale_shift_norm
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
,
swish
=
1.0
),
nn
.
Identity
(),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
),
)
self
.
updown
=
up
or
down
if
up
:
self
.
h_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
x_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
self
.
emb_layers
=
nn
.
Sequential
(
nn
.
SiLU
(),
linear
(
emb_channels
,
2
*
self
.
out_channels
,
),
)
self
.
out_layers
=
nn
.
Sequential
(
normalization
(
self
.
out_channels
,
swish
=
0.0
),
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)),
)
if
self
.
out_channels
==
channels
:
self
.
skip_connection
=
nn
.
Identity
()
elif
use_conv
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
self
.
overwrite
=
overwrite
self
.
is_overwritten
=
False
if
self
.
overwrite
:
in_channels
=
channels
out_channels
=
self
.
out_channels
conv_shortcut
=
False
dropout
=
0.0
temb_channels
=
emb_channels
groups
=
32
pre_norm
=
True
eps
=
1e-5
non_linearity
=
"silu"
self
.
pre_norm
=
pre_norm
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
# Add to init
self
.
time_embedding_norm
=
"scale_shift"
if
self
.
pre_norm
:
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
else
:
self
.
norm1
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
nonlinearity
elif
non_linearity
==
"mish"
:
self
.
nonlinearity
=
Mish
()
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
if
self
.
in_channels
!=
self
.
out_channels
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
up
,
self
.
down
=
up
,
down
# if self.up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=dims)
# elif self.down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
def
set_weights
(
self
):
# TODO(Patrick): use for glide at later stage
self
.
norm1
.
weight
.
data
=
self
.
in_layers
[
0
].
weight
.
data
self
.
norm1
.
bias
.
data
=
self
.
in_layers
[
0
].
bias
.
data
self
.
conv1
.
weight
.
data
=
self
.
in_layers
[
-
1
].
weight
.
data
self
.
conv1
.
bias
.
data
=
self
.
in_layers
[
-
1
].
bias
.
data
self
.
temb_proj
.
weight
.
data
=
self
.
emb_layers
[
-
1
].
weight
.
data
self
.
temb_proj
.
bias
.
data
=
self
.
emb_layers
[
-
1
].
bias
.
data
self
.
norm2
.
weight
.
data
=
self
.
out_layers
[
0
].
weight
.
data
self
.
norm2
.
bias
.
data
=
self
.
out_layers
[
0
].
bias
.
data
self
.
conv2
.
weight
.
data
=
self
.
out_layers
[
-
1
].
weight
.
data
self
.
conv2
.
bias
.
data
=
self
.
out_layers
[
-
1
].
bias
.
data
if
self
.
in_channels
!=
self
.
out_channels
:
self
.
nin_shortcut
.
weight
.
data
=
self
.
skip_connection
.
weight
.
data
self
.
nin_shortcut
.
bias
.
data
=
self
.
skip_connection
.
bias
.
data
def
forward
(
self
,
x
,
emb
):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if
self
.
overwrite
:
# TODO(Patrick): use for glide at later stage
self
.
set_weights
()
orig_x
=
x
if
self
.
updown
:
in_rest
,
in_conv
=
self
.
in_layers
[:
-
1
],
self
.
in_layers
[
-
1
]
h
=
in_rest
(
x
)
h
=
self
.
h_upd
(
h
)
x
=
self
.
x_upd
(
x
)
h
=
in_conv
(
h
)
else
:
h
=
self
.
in_layers
(
x
)
emb_out
=
self
.
emb_layers
(
emb
).
type
(
h
.
dtype
)
while
len
(
emb_out
.
shape
)
<
len
(
h
.
shape
):
emb_out
=
emb_out
[...,
None
]
if
self
.
use_scale_shift_norm
:
out_norm
,
out_rest
=
self
.
out_layers
[
0
],
self
.
out_layers
[
1
:]
scale
,
shift
=
torch
.
chunk
(
emb_out
,
2
,
dim
=
1
)
h
=
out_norm
(
h
)
*
(
1
+
scale
)
+
shift
h
=
out_rest
(
h
)
else
:
h
=
h
+
emb_out
h
=
self
.
out_layers
(
h
)
result
=
self
.
skip_connection
(
x
)
+
h
# TODO(Patrick) Use for glide at later stage
result
=
self
.
forward_2
(
orig_x
,
emb
)
return
result
def
forward_2
(
self
,
x
,
temb
):
if
self
.
overwrite
and
not
self
.
is_overwritten
:
self
.
set_weights
()
self
.
is_overwritten
=
True
h
=
x
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
self
.
up
or
self
.
down
:
x
=
self
.
x_upd
(
x
)
h
=
self
.
h_upd
(
h
)
h
=
self
.
conv1
(
h
)
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
if
self
.
time_embedding_norm
==
"scale_shift"
:
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
h
=
self
.
norm2
(
h
)
h
=
h
+
h
*
scale
+
shift
h
=
self
.
nonlinearity
(
h
)
else
:
h
=
h
+
temb
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
# unet.py, unet_grad_tts.py, unet_ldm.py
class
ResnetBlock
(
nn
.
Module
):
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -445,12 +223,9 @@ class ResnetBlock(nn.Module):
...
@@ -445,12 +223,9 @@ class ResnetBlock(nn.Module):
self
.
x_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
# TODO(Patrick) - this branch is never used I think => can be deleted!
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
self
.
is_overwritten
=
False
self
.
is_overwritten
=
False
self
.
overwrite_for_glide
=
overwrite_for_glide
self
.
overwrite_for_glide
=
overwrite_for_glide
self
.
overwrite_for_grad_tts
=
overwrite_for_grad_tts
self
.
overwrite_for_grad_tts
=
overwrite_for_grad_tts
...
@@ -497,8 +272,6 @@ class ResnetBlock(nn.Module):
...
@@ -497,8 +272,6 @@ class ResnetBlock(nn.Module):
)
)
if
self
.
out_channels
==
in_channels
:
if
self
.
out_channels
==
in_channels
:
self
.
skip_connection
=
nn
.
Identity
()
self
.
skip_connection
=
nn
.
Identity
()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else
:
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
...
@@ -541,6 +314,8 @@ class ResnetBlock(nn.Module):
...
@@ -541,6 +314,8 @@ class ResnetBlock(nn.Module):
self
.
nin_shortcut
.
bias
.
data
=
self
.
skip_connection
.
bias
.
data
self
.
nin_shortcut
.
bias
.
data
=
self
.
skip_connection
.
bias
.
data
def
forward
(
self
,
x
,
temb
,
mask
=
1.0
):
def
forward
(
self
,
x
,
temb
,
mask
=
1.0
):
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
if
self
.
overwrite_for_grad_tts
and
not
self
.
is_overwritten
:
if
self
.
overwrite_for_grad_tts
and
not
self
.
is_overwritten
:
self
.
set_weights_grad_tts
()
self
.
set_weights_grad_tts
()
self
.
is_overwritten
=
True
self
.
is_overwritten
=
True
...
@@ -566,6 +341,7 @@ class ResnetBlock(nn.Module):
...
@@ -566,6 +341,7 @@ class ResnetBlock(nn.Module):
h
=
h
*
mask
h
=
h
*
mask
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
if
self
.
time_embedding_norm
==
"scale_shift"
:
if
self
.
time_embedding_norm
==
"scale_shift"
:
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
...
@@ -589,9 +365,6 @@ class ResnetBlock(nn.Module):
...
@@ -589,9 +365,6 @@ class ResnetBlock(nn.Module):
x
=
x
*
mask
x
=
x
*
mask
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
in_channels
!=
self
.
out_channels
:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
x
=
self
.
nin_shortcut
(
x
)
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
return
x
+
h
...
@@ -605,10 +378,6 @@ class Block(torch.nn.Module):
...
@@ -605,10 +378,6 @@ class Block(torch.nn.Module):
torch
.
nn
.
Conv2d
(
dim
,
dim_out
,
3
,
padding
=
1
),
torch
.
nn
.
GroupNorm
(
groups
,
dim_out
),
Mish
()
torch
.
nn
.
Conv2d
(
dim
,
dim_out
,
3
,
padding
=
1
),
torch
.
nn
.
GroupNorm
(
groups
,
dim_out
),
Mish
()
)
)
def
forward
(
self
,
x
,
mask
):
output
=
self
.
block
(
x
*
mask
)
return
output
*
mask
# unet_score_estimation.py
# unet_score_estimation.py
class
ResnetBlockBigGANpp
(
nn
.
Module
):
class
ResnetBlockBigGANpp
(
nn
.
Module
):
...
...
src/diffusers/models/unet_glide.py
View file @
efe1e60e
...
@@ -6,8 +6,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -6,8 +6,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResBlock
,
TimestepBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
,
TimestepBlock
,
Upsample
from
.resnet
import
ResnetBlock
def
convert_module_to_f16
(
l
):
def
convert_module_to_f16
(
l
):
...
@@ -191,15 +190,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -191,15 +190,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for
level
,
mult
in
enumerate
(
channel_mult
):
for
level
,
mult
in
enumerate
(
channel_mult
):
for
_
in
range
(
num_res_blocks
):
for
_
in
range
(
num_res_blocks
):
layers
=
[
layers
=
[
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=int(mult * model_channels),
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# )
ResnetBlock
(
ResnetBlock
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
mult
*
model_channels
,
out_channels
=
mult
*
model_channels
,
...
@@ -207,7 +197,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -207,7 +197,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
time_embedding_norm
=
"scale_shift"
if
use_scale_shift_norm
else
"default"
,
overwrite_for_glide
=
True
,
overwrite_for_glide
=
True
,
)
)
]
]
...
@@ -229,16 +219,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -229,16 +219,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
out_ch
=
ch
out_ch
=
ch
self
.
input_blocks
.
append
(
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
TimestepEmbedSequential
(
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
ResnetBlock
(
ResnetBlock
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
out_ch
,
out_channels
=
out_ch
,
...
@@ -246,9 +226,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -246,9 +226,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
time_embedding_norm
=
"scale_shift"
if
use_scale_shift_norm
else
"default"
,
overwrite_for_glide
=
True
,
overwrite_for_glide
=
True
,
down
=
True
down
=
True
,
)
)
if
resblock_updown
if
resblock_updown
else
Downsample
(
else
Downsample
(
...
@@ -262,21 +242,13 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -262,21 +242,13 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
self
.
_feature_size
+=
ch
self
.
_feature_size
+=
ch
self
.
middle_block
=
TimestepEmbedSequential
(
self
.
middle_block
=
TimestepEmbedSequential
(
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# ),
ResnetBlock
(
ResnetBlock
(
in_channels
=
ch
,
in_channels
=
ch
,
dropout
=
dropout
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
time_embedding_norm
=
"scale_shift"
if
use_scale_shift_norm
else
"default"
,
overwrite_for_glide
=
True
,
overwrite_for_glide
=
True
,
),
),
AttentionBlock
(
AttentionBlock
(
...
@@ -286,23 +258,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -286,23 +258,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
transformer_dim
,
encoder_channels
=
transformer_dim
,
),
),
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# ),
ResnetBlock
(
ResnetBlock
(
in_channels
=
ch
,
in_channels
=
ch
,
dropout
=
dropout
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
time_embedding_norm
=
"scale_shift"
if
use_scale_shift_norm
else
"default"
,
overwrite_for_glide
=
True
,
overwrite_for_glide
=
True
,
)
)
,
)
)
self
.
_feature_size
+=
ch
self
.
_feature_size
+=
ch
...
@@ -311,15 +275,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -311,15 +275,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for
i
in
range
(
num_res_blocks
+
1
):
for
i
in
range
(
num_res_blocks
+
1
):
ich
=
input_block_chans
.
pop
()
ich
=
input_block_chans
.
pop
()
layers
=
[
layers
=
[
# ResBlock(
# ch + ich,
# time_embed_dim,
# dropout,
# out_channels=int(model_channels * mult),
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# )
ResnetBlock
(
ResnetBlock
(
in_channels
=
ch
+
ich
,
in_channels
=
ch
+
ich
,
out_channels
=
model_channels
*
mult
,
out_channels
=
model_channels
*
mult
,
...
@@ -327,7 +282,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -327,7 +282,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
time_embedding_norm
=
"scale_shift"
if
use_scale_shift_norm
else
"default"
,
overwrite_for_glide
=
True
,
overwrite_for_glide
=
True
,
),
),
]
]
...
@@ -345,16 +300,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -345,16 +300,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
if
level
and
i
==
num_res_blocks
:
if
level
and
i
==
num_res_blocks
:
out_ch
=
ch
out_ch
=
ch
layers
.
append
(
layers
.
append
(
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# up=True,
# )
ResnetBlock
(
ResnetBlock
(
in_channels
=
ch
,
in_channels
=
ch
,
out_channels
=
out_ch
,
out_channels
=
out_ch
,
...
@@ -362,7 +307,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -362,7 +307,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
temb_channels
=
time_embed_dim
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
time_embedding_norm
=
"scale_shift"
if
use_scale_shift_norm
else
"default"
,
overwrite_for_glide
=
True
,
overwrite_for_glide
=
True
,
up
=
True
,
up
=
True
,
)
)
...
...
tests/test_modeling_utils.py
View file @
efe1e60e
...
@@ -795,7 +795,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -795,7 +795,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
sizes
=
(
32
,
32
)
sizes
=
(
32
,
32
)
noise
=
torch
.
randn
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
noise
=
torch
.
randn
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
9.
]).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
9.
0
]).
to
(
torch_device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output
=
model
(
noise
,
time_step
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment