Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fd6f93b2
"vscode:/vscode.git/clone" did not exist on "9c241d83fc15c0e9b11e4726a87511c0f8c68d51"
Commit
fd6f93b2
authored
Jun 30, 2022
by
Patrick von Platen
Browse files
all glide passes
parent
db934c67
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
117 additions
and
62 deletions
+117
-62
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+8
-14
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+109
-48
No files found.
src/diffusers/models/resnet.py
View file @
fd6f93b2
...
...
@@ -378,9 +378,6 @@ class ResBlock(TimestepBlock):
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
...
...
@@ -426,7 +423,7 @@ class ResnetBlock(nn.Module):
if
time_embedding_norm
==
"default"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
if
time_embedding_norm
==
"scale_shift"
:
el
if
time_embedding_norm
==
"scale_shift"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
...
...
@@ -489,7 +486,7 @@ class ResnetBlock(nn.Module):
nn
.
SiLU
(),
linear
(
emb_channels
,
2
*
self
.
out_channels
if
u
se
_
scale_shift
_norm
else
self
.
out_channels
,
2
*
self
.
out_channels
if
se
lf
.
time_embedding_norm
==
"
scale_shift
"
else
self
.
out_channels
,
),
)
self
.
out_layers
=
nn
.
Sequential
(
...
...
@@ -551,9 +548,6 @@ class ResnetBlock(nn.Module):
self
.
set_weights_ldm
()
self
.
is_overwritten
=
True
if
self
.
up
or
self
.
down
:
x
=
self
.
x_upd
(
x
)
h
=
x
h
=
h
*
mask
if
self
.
pre_norm
:
...
...
@@ -561,6 +555,7 @@ class ResnetBlock(nn.Module):
h
=
self
.
nonlinearity
(
h
)
if
self
.
up
or
self
.
down
:
x
=
self
.
x_upd
(
x
)
h
=
self
.
h_upd
(
h
)
h
=
self
.
conv1
(
h
)
...
...
@@ -571,7 +566,6 @@ class ResnetBlock(nn.Module):
h
=
h
*
mask
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
if
self
.
time_embedding_norm
==
"scale_shift"
:
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
...
...
@@ -595,9 +589,9 @@ class ResnetBlock(nn.Module):
x
=
x
*
mask
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
#
if self.use_conv_shortcut:
#
x = self.conv_shortcut(x)
#
else:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
...
...
src/diffusers/models/unet_glide.py
View file @
fd6f93b2
...
...
@@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResBlock
,
TimestepBlock
,
Upsample
from
.resnet
import
ResnetBlock
def
convert_module_to_f16
(
l
):
...
...
@@ -101,7 +102,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def
forward
(
self
,
x
,
emb
,
encoder_out
=
None
):
for
layer
in
self
:
if
isinstance
(
layer
,
TimestepBlock
):
if
isinstance
(
layer
,
TimestepBlock
)
or
isinstance
(
layer
,
ResnetBlock
)
:
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
AttentionBlock
):
x
=
layer
(
x
,
encoder_out
)
...
...
@@ -190,14 +191,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for
level
,
mult
in
enumerate
(
channel_mult
):
for
_
in
range
(
num_res_blocks
):
layers
=
[
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
int
(
mult
*
model_channels
),
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=int(mult * model_channels),
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# )
ResnetBlock
(
in_channels
=
ch
,
out_channels
=
mult
*
model_channels
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
overwrite_for_glide
=
True
,
)
]
ch
=
int
(
mult
*
model_channels
)
...
...
@@ -218,15 +229,26 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
out_ch
=
ch
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
ResnetBlock
(
in_channels
=
ch
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
overwrite_for_glide
=
True
,
down
=
True
)
if
resblock_updown
else
Downsample
(
...
...
@@ -240,13 +262,22 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
self
.
_feature_size
+=
ch
self
.
middle_block
=
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# ),
ResnetBlock
(
in_channels
=
ch
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
overwrite_for_glide
=
True
,
),
AttentionBlock
(
ch
,
...
...
@@ -255,14 +286,23 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
num_head_channels
=
num_head_channels
,
encoder_channels
=
transformer_dim
,
),
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
),
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# ),
ResnetBlock
(
in_channels
=
ch
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
overwrite_for_glide
=
True
,
)
)
self
.
_feature_size
+=
ch
...
...
@@ -271,15 +311,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
for
i
in
range
(
num_res_blocks
+
1
):
ich
=
input_block_chans
.
pop
()
layers
=
[
ResBlock
(
ch
+
ich
,
time_embed_dim
,
dropout
,
out_channels
=
int
(
model_channels
*
mult
),
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
# ResBlock(
# ch + ich,
# time_embed_dim,
# dropout,
# out_channels=int(model_channels * mult),
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# )
ResnetBlock
(
in_channels
=
ch
+
ich
,
out_channels
=
model_channels
*
mult
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
overwrite_for_glide
=
True
,
),
]
ch
=
int
(
model_channels
*
mult
)
if
ds
in
attention_resolutions
:
...
...
@@ -295,14 +345,25 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
if
level
and
i
==
num_res_blocks
:
out_ch
=
ch
layers
.
append
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# up=True,
# )
ResnetBlock
(
in_channels
=
ch
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
dropout
=
dropout
,
temb_channels
=
time_embed_dim
,
eps
=
1e-5
,
non_linearity
=
"silu"
,
time_embedding_norm
=
"scale_shift"
,
overwrite_for_glide
=
True
,
up
=
True
,
)
if
resblock_updown
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment