renzhc / diffusers_dcu / Commits / fb188cd3

Unverified commit fb188cd3, authored Jul 01, 2022 by Patrick von Platen, committed by GitHub on Jul 01, 2022.

Merge pull request #55 from huggingface/refactor_glide

[Resnet] Merge glide resnet into general resnet
Parents: c1c4dea9, efe1e60e

Changes: 3 changed files with 100 additions and 286 deletions (+100 −286)
src/diffusers/models/resnet.py      +45 −237
src/diffusers/models/unet_glide.py  +53 −47
tests/test_modeling_utils.py        +2  −2
src/diffusers/models/resnet.py (view file @ fb188cd3)

```diff
 import string
 from abc import abstractmethod

 import numpy as np
@@ -162,221 +161,7 @@ class Downsample(nn.Module):


 # RESNETS


-# unet_glide.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
```

The deleted lines are the entire GLIDE-specific `ResBlock` class:

```python
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
    use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
    on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
    downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
        overwrite=False,  # TODO(Patrick) - use for glide at later stage
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels, swish=1.0),
            nn.Identity(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down
        # if self.updown:
        #     import ipdb; ipdb.set_trace()

        if up:
            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
        elif down:
            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

        self.overwrite = overwrite
        self.is_overwritten = False
        if self.overwrite:
            in_channels = channels
            out_channels = self.out_channels
            conv_shortcut = False
            dropout = 0.0
            temb_channels = emb_channels
            groups = 32
            pre_norm = True
            eps = 1e-5
            non_linearity = "silu"

            self.pre_norm = pre_norm
            self.in_channels = in_channels
            out_channels = in_channels if out_channels is None else out_channels
            self.out_channels = out_channels
            self.use_conv_shortcut = conv_shortcut

            if self.pre_norm:
                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
            else:
                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)

            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
            self.dropout = torch.nn.Dropout(dropout)
            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

            if non_linearity == "swish":
                self.nonlinearity = nonlinearity
            elif non_linearity == "mish":
                self.nonlinearity = Mish()
            elif non_linearity == "silu":
                self.nonlinearity = nn.SiLU()

            if self.in_channels != self.out_channels:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def set_weights(self):
        # TODO(Patrick): use for glide at later stage
        self.norm1.weight.data = self.in_layers[0].weight.data
        self.norm1.bias.data = self.in_layers[0].bias.data

        self.conv1.weight.data = self.in_layers[-1].weight.data
        self.conv1.bias.data = self.in_layers[-1].bias.data

        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
        self.temb_proj.bias.data = self.emb_layers[-1].bias.data

        self.norm2.weight.data = self.out_layers[0].weight.data
        self.norm2.bias.data = self.out_layers[0].bias.data

        self.conv2.weight.data = self.out_layers[-1].weight.data
        self.conv2.bias.data = self.out_layers[-1].bias.data

        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.skip_connection.weight.data
            self.nin_shortcut.bias.data = self.skip_connection.bias.data

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep
        embeddings. :return: an [N x C x ...] Tensor of outputs.
        """
        if self.overwrite:
            # TODO(Patrick): use for glide at later stage
            self.set_weights()

        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]

        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)

        result = self.skip_connection(x) + h

        # TODO(Patrick) Use for glide at later stage
        # result = self.forward_2(x, emb)
        return result

    def forward_2(self, x, temb, mask=1.0):
        if self.overwrite and not self.is_overwritten:
            self.set_weights()
            self.is_overwritten = True

        h = x

        if self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

        h = self.conv1(h)

        if not self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

        if self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)

        h = self.dropout(h)
        h = self.conv2(h)

        if not self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
```
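Aside on the `set_weights` pattern above: it ports parameters from the `Sequential`-based GLIDE layout into the flat `norm1`/`conv1`/... layout, so the same checkpoint can drive both parameterizations during the transition. A minimal runnable sketch of that idea, with illustrative module names that are not taken from this file:

```python
import torch
from torch import nn

# Copy parameters from a Sequential layout (norm -> identity -> conv) into
# equivalently shaped flat modules, then check both compute the same thing.
src = nn.Sequential(nn.GroupNorm(8, 32), nn.Identity(), nn.Conv2d(32, 32, 3, padding=1))
dst_norm = nn.GroupNorm(8, 32)
dst_conv = nn.Conv2d(32, 32, 3, padding=1)

dst_norm.weight.data = src[0].weight.data
dst_norm.bias.data = src[0].bias.data
dst_conv.weight.data = src[-1].weight.data
dst_conv.bias.data = src[-1].bias.data

x = torch.randn(1, 32, 8, 8)
assert torch.allclose(src(x), dst_conv(dst_norm(x)))  # identical outputs after porting
```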
The general `ResnetBlock` that absorbs it then picks up these changes:

```diff
-# unet.py and unet_grad_tts.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
@@ -390,8 +175,12 @@ class ResnetBlock(nn.Module):
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
+        time_embedding_norm="default",
+        up=False,
+        down=False,
         overwrite_for_grad_tts=False,
         overwrite_for_ldm=False,
+        overwrite_for_glide=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -399,6 +188,9 @@ class ResnetBlock(nn.Module):
         out_channels = in_channels if out_channels is None else out_channels
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut
+        self.time_embedding_norm = time_embedding_norm
+        self.up = up
+        self.down = down

         if self.pre_norm:
             self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -406,10 +198,16 @@ class ResnetBlock(nn.Module):
             self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)

         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        if time_embedding_norm == "default":
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        elif time_embedding_norm == "scale_shift":
+            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
         self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if non_linearity == "swish":
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
@@ -417,16 +215,21 @@ class ResnetBlock(nn.Module):
         elif non_linearity == "silu":
             self.nonlinearity = nn.SiLU()

+        if up:
+            self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
+            self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
+        elif down:
+            self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+            self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

         if self.in_channels != self.out_channels:
-            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            if self.use_conv_shortcut:
+                # TODO(Patrick) - this branch is never used I think => can be deleted!
+                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

+        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
         self.is_overwritten = False
+        self.overwrite_for_glide = overwrite_for_glide
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
-        self.overwrite_for_ldm = overwrite_for_ldm
+        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -458,7 +261,7 @@ class ResnetBlock(nn.Module):
                 nn.SiLU(),
                 linear(
                     emb_channels,
-                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                    2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
                 ),
             )
             self.out_layers = nn.Sequential(
@@ -469,8 +272,6 @@ class ResnetBlock(nn.Module):
             )

             if self.out_channels == in_channels:
                 self.skip_connection = nn.Identity()
-            # elif use_conv:
-            #     self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
             else:
                 self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@@ -513,6 +314,8 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.bias.data = self.skip_connection.bias.data

     def forward(self, x, temb, mask=1.0):
+        # TODO(Patrick) eventually this class should be split into multiple classes
+        # too many if else statements
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
@@ -526,6 +329,10 @@ class ResnetBlock(nn.Module):
             h = self.norm1(h)
             h = self.nonlinearity(h)

+        if self.up or self.down:
+            x = self.x_upd(x)
+            h = self.h_upd(h)
+
         h = self.conv1(h)

         if not self.pre_norm:
@@ -533,12 +340,20 @@ class ResnetBlock(nn.Module):
             h = self.nonlinearity(h)

         h = h * mask
-        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-
-        h = h * mask
-        if self.pre_norm:
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
+        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if self.time_embedding_norm == "scale_shift":
+            scale, shift = torch.chunk(temb, 2, dim=1)
+
+            h = h * mask
+            if self.pre_norm:
+                h = self.norm2(h)
+
+            h = h + h * scale + shift
+            h = self.nonlinearity(h)
+        elif self.time_embedding_norm == "default":
+            h = h + temb
+
+            h = h * mask
+            if self.pre_norm:
+                h = self.norm2(h)
+                h = self.nonlinearity(h)

         h = self.dropout(h)
         h = self.conv2(h)
@@ -550,10 +365,7 @@ class ResnetBlock(nn.Module):
         x = x * mask
         if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)

         return x + h
@@ -566,10 +378,6 @@ class Block(torch.nn.Module):
             torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
         )

-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask


 # unet_score_estimation.py
 class ResnetBlockBigGANpp(nn.Module):
```
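Net effect in resnet.py: the GLIDE `ResBlock` is deleted, and `ResnetBlock` gains `time_embedding_norm`, `up`/`down`, and an `overwrite_for_glide` flag so it can reproduce the same computation. The heart of the new `"scale_shift"` path is `h = norm2(h) + norm2(h) * scale + shift`, i.e. `norm2(h) * (1 + scale) + shift`. A short self-contained sketch of that conditioning step (shapes and names here are illustrative, not the module's attributes):

```python
import torch
from torch import nn

# Sketch of "scale_shift" time-embedding conditioning, assuming the timestep
# embedding was already projected to 2 * C channels (the temb_proj branch above).
N, C, H, W = 2, 32, 8, 8
h = torch.randn(N, C, H, W)
temb = torch.randn(N, 2 * C)[:, :, None, None]   # [N, 2C, 1, 1], broadcasts over H, W
norm2 = nn.GroupNorm(num_groups=8, num_channels=C)

scale, shift = torch.chunk(temb, 2, dim=1)       # two [N, C, 1, 1] tensors
h = norm2(h)
h = h + h * scale + shift                        # == norm2(h) * (1 + scale) + shift
```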
src/diffusers/models/unet_glide.py (view file @ fb188cd3)

```diff
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample


 def convert_module_to_f16(l):
@@ -101,7 +101,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, encoder_out=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, AttentionBlock):
                 x = layer(x, encoder_out)
@@ -190,14 +190,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(mult * model_channels),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
                     )
                 ]
                 ch = int(mult * model_channels)
@@ -218,14 +219,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                        ResnetBlock(
+                            in_channels=ch,
+                            out_channels=out_ch,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
                             down=True,
                         )
                         if resblock_updown
@@ -240,13 +242,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         self._feature_size += ch

         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
             ),
             AttentionBlock(
                 ch,
@@ -255,13 +258,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
             ),
         )
         self._feature_size += ch
@@ -271,15 +275,16 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(model_channels * mult),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch + ich,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
+                    ),
                 ]
                 ch = int(model_channels * mult)
                 if ds in attention_resolutions:
@@ -295,14 +300,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                        ResnetBlock(
+                            in_channels=ch,
+                            out_channels=out_ch,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
                             up=True,
                         )
                         if resblock_updown
```
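The `TimestepEmbedSequential` tweak above is a dispatch change: layers that are either `TimestepBlock` subclasses or the new `ResnetBlock` receive the timestep embedding, while attention layers receive the optional encoder output. A condensed, self-contained sketch of the pattern (the stub classes below stand in for the real ones in `resnet.py` and `attention.py` and are not the library implementations):

```python
import torch
from torch import nn

# Minimal stand-ins so the sketch runs on its own.
class TimestepBlock(nn.Module):
    def forward(self, x, emb):
        return x

class ResnetBlock(nn.Module):
    def forward(self, x, emb):
        return x

class AttentionBlock(nn.Module):
    def forward(self, x, encoder_out=None):
        return x

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    def forward(self, x, emb, encoder_out=None):
        for layer in self:
            if isinstance(layer, (TimestepBlock, ResnetBlock)):
                x = layer(x, emb)          # timestep-conditioned layers get the embedding
            elif isinstance(layer, AttentionBlock):
                x = layer(x, encoder_out)  # attention may cross-attend to encoder features
            else:
                x = layer(x)
        return x

seq = TimestepEmbedSequential(ResnetBlock(), AttentionBlock())
out = seq(torch.randn(1, 3, 8, 8), torch.randn(1, 16))
```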
tests/test_modeling_utils.py (view file @ fb188cd3)

```diff
@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         # fmt: off
         expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


 class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
@@ -795,7 +795,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         sizes = (32, 32)
         noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [9.]).to(torch_device)
+        time_step = torch.tensor(batch_size * [9.0]).to(torch_device)

         with torch.no_grad():
             output = model(noise, time_step)
```
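For context on the tolerance change: `torch.allclose(input, other, rtol, atol)` passes when `|input - other| <= atol + rtol * |other|` elementwise, so raising `rtol` from 1e-3 to 1e-2 tolerates roughly 1% relative drift, which plausibly absorbs the reordering of floating-point operations introduced by the refactor. A quick illustration with made-up numbers:

```python
import torch

expected = torch.tensor([0.2891, -0.1899, 0.2595])
got = expected * 1.005                           # 0.5% relative drift

print(torch.allclose(got, expected, rtol=1e-3))  # False: drift exceeds 0.1%
print(torch.allclose(got, expected, rtol=1e-2))  # True: within 1%
```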