GitLab: renzhc / diffusers_dcu · Commits

Commit 358531be, authored Jun 29, 2022 by Patrick von Platen
Commit message: up
Parent: 597b7ae2

Showing 2 changed files with 271 additions and 93 deletions (+271 / -93):
  src/diffusers/models/resnet.py    (+177 / -8)
  src/diffusers/models/unet_ldm.py  (+94 / -85)
src/diffusers/models/resnet.py  (view file @ 358531be)

@@ -162,7 +162,7 @@ class Downsample(nn.Module):
 # RESNETS

-# unet_glide.py & unet_ldm.py
+# unet_glide.py
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -188,6 +188,7 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
+        overwrite=False,  # TODO(Patrick) - use for glide at later stage
     ):
         super().__init__()
         self.channels = channels
@@ -236,6 +237,65 @@ class ResBlock(TimestepBlock):
         else:
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

+        self.overwrite = overwrite
+        self.is_overwritten = False
+        if self.overwrite:
+            in_channels = channels
+            out_channels = self.out_channels
+            conv_shortcut = False
+            dropout = 0.0
+            temb_channels = emb_channels
+            groups = 32
+            pre_norm = True
+            eps = 1e-5
+            non_linearity = "silu"
+
+            self.pre_norm = pre_norm
+            self.in_channels = in_channels
+            out_channels = in_channels if out_channels is None else out_channels
+            self.out_channels = out_channels
+            self.use_conv_shortcut = conv_shortcut
+
+            if self.pre_norm:
+                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+            else:
+                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
+            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+            self.dropout = torch.nn.Dropout(dropout)
+            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+            if non_linearity == "swish":
+                self.nonlinearity = nonlinearity
+            elif non_linearity == "mish":
+                self.nonlinearity = Mish()
+            elif non_linearity == "silu":
+                self.nonlinearity = nn.SiLU()
+
+            if self.in_channels != self.out_channels:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def set_weights(self):
+        # TODO(Patrick): use for glide at later stage
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
     def forward(self, x, emb):
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.
@@ -243,6 +303,10 @@ class ResBlock(TimestepBlock):
         :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        if self.overwrite:
+            # TODO(Patrick): use for glide at later stage
+            self.set_weights()
+
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)

@@ -251,6 +315,7 @@ class ResBlock(TimestepBlock):
             h = in_conv(h)
         else:
             h = self.in_layers(x)
+
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
@@ -262,7 +327,50 @@ class ResBlock(TimestepBlock):
         else:
             h = h + emb_out
             h = self.out_layers(h)
-        return self.skip_connection(x) + h
+
+        result = self.skip_connection(x) + h
+
+        # TODO(Patrick) Use for glide at later stage
+        # result = self.forward_2(x, emb)
+
+        return result
+
+    def forward_2(self, x, temb, mask=1.0):
+        if self.overwrite and not self.is_overwritten:
+            self.set_weights()
+            self.is_overwritten = True
+
+        h = x
+
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = self.conv1(h)
+
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h


 # unet.py and unet_grad_tts.py
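
Note on the ResBlock changes above: the overwrite branch builds explicit norm1 / conv1 / temb_proj / norm2 / conv2 modules that mirror the existing in_layers / emb_layers / out_layers stacks, and set_weights() copies the parameters across so forward_2 can reproduce the original computation. A minimal, self-contained sketch of that parameter-copying technique (hypothetical Packed / Unpacked modules, not the diffusers classes):

import torch
import torch.nn as nn

class Packed(nn.Module):
    # Stands in for the existing in_layers = Sequential(norm, act, conv) stack.
    def __init__(self):
        super().__init__()
        self.in_layers = nn.Sequential(nn.GroupNorm(4, 8), nn.SiLU(), nn.Conv2d(8, 8, 3, padding=1))

    def forward(self, x):
        return self.in_layers(x)

class Unpacked(nn.Module):
    # Stands in for the new explicitly named norm1 / conv1 modules.
    def __init__(self):
        super().__init__()
        self.norm1 = nn.GroupNorm(4, 8)
        self.nonlinearity = nn.SiLU()
        self.conv1 = nn.Conv2d(8, 8, 3, padding=1)

    def set_weights(self, packed):
        # Same idea as ResBlock.set_weights(): copy parameters so both
        # parameterizations compute the same function.
        self.norm1.weight.data = packed.in_layers[0].weight.data
        self.norm1.bias.data = packed.in_layers[0].bias.data
        self.conv1.weight.data = packed.in_layers[-1].weight.data
        self.conv1.bias.data = packed.in_layers[-1].bias.data

    def forward(self, x):
        return self.conv1(self.nonlinearity(self.norm1(x)))

packed, unpacked = Packed(), Unpacked()
unpacked.set_weights(packed)
x = torch.randn(2, 8, 16, 16)
assert torch.allclose(packed(x), unpacked(x))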
@@ -280,6 +388,7 @@ class ResnetBlock(nn.Module):
         eps=1e-6,
         non_linearity="swish",
         overwrite_for_grad_tts=False,
+        overwrite_for_ldm=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm

@@ -302,15 +411,19 @@ class ResnetBlock(nn.Module):
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
+                # TODO(Patrick) - this branch is never used I think => can be deleted!
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

         self.is_overwritten = False
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        self.overwrite_for_ldm = overwrite_for_ldm

         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -324,6 +437,39 @@ class ResnetBlock(nn.Module):
                 self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
             else:
                 self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+            # eps = 1e-5
+            # non_linearity = "silu"
+            # overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+            # elif use_conv:
+            #     self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
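
The out_layers stack built in the overwrite_for_ldm branch ends in zero_module(conv_nd(...)). Assuming zero_module zeroes a module's parameters in place (as in the guided-diffusion utilities this code follows), the residual branch contributes nothing at initialization, so the block starts out as its skip connection. A small sketch of that idea:

import torch
import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    # Assumed behaviour: zero all parameters and return the module.
    for p in module.parameters():
        p.detach().zero_()
    return module

out_conv = zero_module(nn.Conv2d(64, 64, 3, padding=1))
x = torch.randn(1, 64, 8, 8)
# The residual branch outputs zeros, so skip(x) + branch(x) == skip(x) at init.
assert torch.all(out_conv(x) == 0)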
@@ -343,13 +489,36 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.weight.data = self.res_conv.weight.data
             self.nin_shortcut.bias.data = self.res_conv.bias.data

-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True

         h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
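
The forward signature changes from mask=None to mask=1.0, and the "if mask is not None" guards collapse to plain h = h * mask (continued in the hunks below). Multiplying by the scalar 1.0 broadcasts and leaves the tensor unchanged, so callers that never pass a mask (the LDM path) get identical results, while grad-tts-style tensor masks still work. A quick sketch:

import torch

h = torch.randn(2, 4, 8, 8)

# No mask supplied: the scalar default 1.0 is a no-op.
assert torch.equal(h * 1.0, h)

# A real mask (e.g. zeroing padded positions, as in grad-tts) still broadcasts.
mask = torch.ones(2, 1, 8, 8)
mask[..., 4:] = 0.0
assert torch.all((h * mask)[..., 4:] == 0)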
@@ -359,11 +528,11 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

         h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)

@@ -374,9 +543,9 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

-        x = x * mask if mask is not None else x
+        x = x * mask
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
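
ResnetBlock.forward keeps both activation orderings: with pre_norm=True each convolution is preceded by normalization and the nonlinearity, with pre_norm=False they follow it. A minimal sketch of the two orderings, using GroupNorm + SiLU as in the code above:

import torch
import torch.nn as nn

norm = nn.GroupNorm(8, 32)
act = nn.SiLU()
conv = nn.Conv2d(32, 32, 3, padding=1)
x = torch.randn(1, 32, 16, 16)

h_pre = conv(act(norm(x)))    # pre_norm=True:  norm -> nonlinearity -> conv
h_post = act(norm(conv(x)))   # pre_norm=False: conv -> norm -> nonlinearity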
src/diffusers/models/unet_ldm.py  (view file @ 358531be)

@@ -10,7 +10,9 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, TimestepBlock, Upsample
+from .resnet import ResnetBlock
+#from .resnet import ResBlock


 def exists(val):
@@ -364,7 +366,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, context=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context)
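
TimestepEmbedSequential passes the timestep embedding only to layers that accept it; since ResnetBlock does not inherit from TimestepBlock, the isinstance check now has to name it explicitly. A stripped-down sketch of the dispatch pattern (hypothetical classes, not the diffusers ones):

import torch
import torch.nn as nn

class TimestepAware(nn.Module):
    def __init__(self, dim, emb_dim):
        super().__init__()
        self.proj = nn.Linear(emb_dim, dim)

    def forward(self, x, emb):
        return x + self.proj(emb)[:, :, None, None]

class EmbSequential(nn.Sequential):
    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepAware):
                x = layer(x, emb)   # timestep-conditioned layers get emb
            else:
                x = layer(x)        # plain layers (convs, norms, ...) do not
        return x

block = EmbSequential(nn.Conv2d(8, 8, 3, padding=1), TimestepAware(8, 32))
out = block(torch.randn(2, 8, 16, 16), torch.randn(2, 32))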
@@ -559,14 +561,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch,
                         out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
                     )
                 ]
                 ch = mult * model_channels
@@ -599,16 +601,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     down=True,
+                        # )
+                        None
                         if resblock_updown
                         else Downsample(
                             ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"

@@ -629,13 +632,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         # num_heads = 1
         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,
@@ -646,13 +650,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             )
             if not use_spatial_transformer
             else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
@@ -662,15 +667,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch + ich,
                         out_channels=model_channels * mult,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = model_channels * mult
                 if ds in attention_resolutions:
@@ -698,16 +703,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            up=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     up=True,
+                        # )
+                        None
                         if resblock_updown
                         else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                     )
@@ -842,15 +848,15 @@ class EncoderUNetModel(nn.Module):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = mult * model_channels
                 if ds in attention_resolutions:
@@ -870,16 +876,17 @@ class EncoderUNetModel(nn.Module):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     down=True,
+                        # )
+                        None
                         if resblock_updown
                         else Downsample(
                             ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -892,13 +899,14 @@ class EncoderUNetModel(nn.Module):
         self._feature_size += ch
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,

@@ -907,13 +915,14 @@ class EncoderUNetModel(nn.Module):
                 num_head_channels=num_head_channels,
                 use_new_attention_order=use_new_attention_order,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
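
All residual blocks in unet_ldm.py are now built through the keyword interface of ResnetBlock with overwrite_for_ldm=True. A hedged usage sketch of that call pattern, assuming this revision of diffusers is importable; only the keywords shown in the diff are used, and anything else about the signature is an assumption:

import torch
from diffusers.models.resnet import ResnetBlock

# Mirrors the construction pattern introduced in this commit.
block = ResnetBlock(
    in_channels=64,
    out_channels=128,
    dropout=0.0,
    temb_channels=256,
    eps=1e-5,
    non_linearity="silu",
    overwrite_for_ldm=True,
)

x = torch.randn(1, 64, 32, 32)   # feature map
temb = torch.randn(1, 256)       # timestep embedding
out = block(x, temb)             # mask defaults to 1.0 in this commit
print(out.shape)                 # expected: torch.Size([1, 128, 32, 32])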