OpenDAS / diffusers / Commits / b65eb377

Unverified commit b65eb377, authored Jun 29, 2022 by Patrick von Platen, committed by GitHub on Jun 29, 2022.

Merge pull request #46 from huggingface/merge_ldm_resnet

[ResNet Refactor] Merge ldm into resnet

Parents: 66ee73ee, 26ce60c4
Showing 2 changed files with 271 additions and 92 deletions (+271 −92):
    src/diffusers/models/resnet.py    +177 −8
    src/diffusers/models/unet_ldm.py  +94 −84
src/diffusers/models/resnet.py  (view file @ b65eb377)
@@ -162,7 +162,7 @@ class Downsample(nn.Module):
 # RESNETS
-# unet_glide.py & unet_ldm.py
+# unet_glide.py
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -188,6 +188,7 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
+        overwrite=False,  # TODO(Patrick) - use for glide at later stage
     ):
         super().__init__()
         self.channels = channels
@@ -236,6 +237,65 @@ class ResBlock(TimestepBlock):
         else:
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
+        self.overwrite = overwrite
+        self.is_overwritten = False
+        if self.overwrite:
+            in_channels = channels
+            out_channels = self.out_channels
+            conv_shortcut = False
+            dropout = 0.0
+            temb_channels = emb_channels
+            groups = 32
+            pre_norm = True
+            eps = 1e-5
+            non_linearity = "silu"
+
+            self.pre_norm = pre_norm
+            self.in_channels = in_channels
+            out_channels = in_channels if out_channels is None else out_channels
+            self.out_channels = out_channels
+            self.use_conv_shortcut = conv_shortcut
+
+            if self.pre_norm:
+                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+            else:
+                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
+            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+            self.dropout = torch.nn.Dropout(dropout)
+            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+            if non_linearity == "swish":
+                self.nonlinearity = nonlinearity
+            elif non_linearity == "mish":
+                self.nonlinearity = Mish()
+            elif non_linearity == "silu":
+                self.nonlinearity = nn.SiLU()
+
+            if self.in_channels != self.out_channels:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def set_weights(self):
+        # TODO(Patrick): use for glide at later stage
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
     def forward(self, x, emb):
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.
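A note on the pattern above: `set_weights` ports the parameters of the fused `in_layers` / `emb_layers` / `out_layers` stacks into the standalone `norm1` / `conv1` / `temb_proj` / `norm2` / `conv2` modules, so the refactored block can reproduce the original block's outputs exactly. A minimal, self-contained sketch of the same copy-then-verify pattern (toy shapes and names, not the diffusers classes):

    import torch
    import torch.nn as nn

    # Toy "old" block: norm + activation + conv fused into one nn.Sequential.
    old = nn.Sequential(nn.GroupNorm(4, 8), nn.SiLU(), nn.Conv2d(8, 8, 3, padding=1))

    # Toy "new" block: the same layers held as separate attributes.
    norm1 = nn.GroupNorm(4, 8)
    conv1 = nn.Conv2d(8, 8, 3, padding=1)

    # Copy weights the way set_weights does: index 0 is the norm, index -1 the conv.
    norm1.weight.data = old[0].weight.data
    norm1.bias.data = old[0].bias.data
    conv1.weight.data = old[-1].weight.data
    conv1.bias.data = old[-1].bias.data

    # The two paths now agree numerically.
    x = torch.randn(2, 8, 16, 16)
    assert torch.allclose(old(x), conv1(nn.SiLU()(norm1(x))))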
@@ -243,6 +303,10 @@ class ResBlock(TimestepBlock):
         :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        if self.overwrite:
+            # TODO(Patrick): use for glide at later stage
+            self.set_weights()
+
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -251,6 +315,7 @@ class ResBlock(TimestepBlock):
             h = in_conv(h)
         else:
             h = self.in_layers(x)
+
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
@@ -262,7 +327,50 @@ class ResBlock(TimestepBlock):
         else:
             h = h + emb_out
             h = self.out_layers(h)
-        return self.skip_connection(x) + h
+        result = self.skip_connection(x) + h
+
+        # TODO(Patrick) Use for glide at later stage
+        # result = self.forward_2(x, emb)
+
+        return result
+
+    def forward_2(self, x, temb, mask=1.0):
+        if self.overwrite and not self.is_overwritten:
+            self.set_weights()
+            self.is_overwritten = True
+
+        h = x
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = self.conv1(h)
+
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
 
 
 # unet.py and unet_grad_tts.py
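The `forward_2` added above mirrors `ResnetBlock.forward`: the `pre_norm` flag selects whether normalization and activation run before or after each convolution. Distilled to a sketch of the two orderings:

    def residual_branch(h, norm, act, conv, pre_norm=True):
        # pre_norm=True: norm/act before the conv (the DDPM/LDM ordering here);
        # pre_norm=False: conv first, then norm/act.
        return conv(act(norm(h))) if pre_norm else act(norm(conv(h)))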
@@ -280,6 +388,7 @@ class ResnetBlock(nn.Module):
         eps=1e-6,
         non_linearity="swish",
         overwrite_for_grad_tts=False,
+        overwrite_for_ldm=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -302,15 +411,19 @@ class ResnetBlock(nn.Module):
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
 
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
+                # TODO(Patrick) - this branch is never used I think => can be deleted!
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
         self.is_overwritten = False
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        self.overwrite_for_ldm = overwrite_for_ldm
 
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -324,6 +437,39 @@ class ResnetBlock(nn.Module):
                 self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
             else:
                 self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+            # eps = 1e-5
+            # non_linearity = "silu"
+            # overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+            # elif use_conv:
+            #     self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
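The `overwrite_for_ldm` branch above rebuilds the fused stacks of the original LDM `ResBlock` using the upstream helpers `normalization`, `conv_nd`, `linear`, and `zero_module`. For readers without the guided-diffusion code at hand, they behave roughly as follows (a sketch of the upstream conventions, not the exact diffusers source; the fused-swish behavior of `normalization` is elided):

    import torch.nn as nn

    def conv_nd(dims, *args, **kwargs):
        # Dispatch to 1D/2D/3D convolution; the LDM UNet uses dims=2.
        return {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims](*args, **kwargs)

    def linear(*args, **kwargs):
        return nn.Linear(*args, **kwargs)

    def normalization(channels, swish=0.0):
        # GroupNorm with 32 groups; upstream, swish=1.0 additionally applies
        # x * sigmoid(x) after normalizing (elided in this sketch).
        return nn.GroupNorm(32, channels)

    def zero_module(module):
        # Zero-init a module's parameters so the residual branch starts out
        # as an identity mapping.
        for p in module.parameters():
            p.detach().zero_()
        return module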
@@ -343,13 +489,36 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.weight.data = self.res_conv.weight.data
             self.nin_shortcut.bias.data = self.res_conv.bias.data
 
-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True
 
         h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
@@ -359,11 +528,11 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask
 
         h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-        h = h * mask if mask is not None else h
+        h = h * mask
 
         if self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
@@ -374,9 +543,9 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask
 
-        x = x * mask if mask is not None else x
+        x = x * mask
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
src/diffusers/models/unet_ldm.py  (view file @ b65eb377)
@@ -10,7 +10,10 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
+
+# from .resnet import ResBlock
 
 
 def exists(val):
@@ -364,7 +367,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, context=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context)
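`TimestepEmbedSequential` is a plain `nn.Sequential` whose `forward` routes extra arguments by layer type; the change above simply adds `ResnetBlock` to the set of layers that receive the timestep embedding (its `forward(self, x, temb, mask=1.0)` accepts the same positional pair). The dispatch pattern in isolation (a minimal sketch with stand-in layer types):

    import torch
    import torch.nn as nn

    class TimestepBlock(nn.Module):
        """Marker base class for layers that take (x, emb)."""

    class AddEmb(TimestepBlock):
        def forward(self, x, emb):
            return x + emb

    class Scale(nn.Module):
        def forward(self, x):
            return 2 * x

    class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        def forward(self, x, emb):
            for layer in self:
                # Timestep-aware layers get the embedding; the rest just get x.
                if isinstance(layer, TimestepBlock):
                    x = layer(x, emb)
                else:
                    x = layer(x)
            return x

    seq = TimestepEmbedSequential(AddEmb(), Scale())
    print(seq(torch.zeros(3), torch.ones(3)))  # tensor([2., 2., 2.])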
@@ -559,14 +562,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch,
                         out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
                     )
                 ]
                 ch = mult * model_channels
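After this change every LDM call site builds the shared `ResnetBlock` with keyword arguments in place of the positional `ResBlock` signature. Schematically (illustrative channel sizes; in the model they come from `model_channels * mult`, and the import assumes a development checkout at this commit):

    import torch
    from diffusers.models.resnet import ResnetBlock

    block = ResnetBlock(
        in_channels=64,
        out_channels=64,
        dropout=0.0,
        temb_channels=512,
        eps=1e-5,
        non_linearity="silu",
        overwrite_for_ldm=True,
    )
    x = torch.randn(1, 64, 32, 32)  # feature map
    temb = torch.randn(1, 512)      # timestep embedding
    h = block(x, temb)              # forward(self, x, temb, mask=1.0)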
@@ -599,16 +602,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     down=True,
+                        # )
+                        None
                         if resblock_updown
                         else Downsample(
                             ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -629,13 +633,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             # num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,
@@ -646,13 +651,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             )
             if not use_spatial_transformer
             else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
@@ -662,15 +668,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch + ich,
                         out_channels=model_channels * mult,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = model_channels * mult
                 if ds in attention_resolutions:
@@ -698,16 +704,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            up=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     up=True,
+                        # )
+                        None
                         if resblock_updown
                         else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                     )
@@ -842,15 +849,15 @@ class EncoderUNetModel(nn.Module):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = mult * model_channels
                 if ds in attention_resolutions:
@@ -870,16 +877,17 @@ class EncoderUNetModel(nn.Module):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     down=True,
+                        # )
+                        None
                         if resblock_updown
                         else Downsample(
                             ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -892,13 +900,14 @@ class EncoderUNetModel(nn.Module):
         self._feature_size += ch
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,
@@ -907,13 +916,14 @@ class EncoderUNetModel(nn.Module):
                 num_head_channels=num_head_channels,
                 use_new_attention_order=use_new_attention_order,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch