Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
26ce60c4
Commit
26ce60c4
authored
Jun 29, 2022
by
Patrick von Platen
Browse files
up
parent
358531be
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
48 deletions
+49
-48
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+7
-7
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+42
-41
No files found.
src/diffusers/models/resnet.py
View file @
26ce60c4
...
@@ -330,8 +330,8 @@ class ResBlock(TimestepBlock):
...
@@ -330,8 +330,8 @@ class ResBlock(TimestepBlock):
result
=
self
.
skip_connection
(
x
)
+
h
result
=
self
.
skip_connection
(
x
)
+
h
# TODO(Patrick) Use for glide at later stage
# TODO(Patrick) Use for glide at later stage
# result = self.forward_2(x, emb)
# result = self.forward_2(x, emb)
return
result
return
result
...
@@ -439,9 +439,9 @@ class ResnetBlock(nn.Module):
...
@@ -439,9 +439,9 @@ class ResnetBlock(nn.Module):
self
.
res_conv
=
torch
.
nn
.
Identity
()
self
.
res_conv
=
torch
.
nn
.
Identity
()
elif
self
.
overwrite_for_ldm
:
elif
self
.
overwrite_for_ldm
:
dims
=
2
dims
=
2
# eps = 1e-5
# eps = 1e-5
# non_linearity = "silu"
# non_linearity = "silu"
# overwrite_for_ldm
# overwrite_for_ldm
channels
=
in_channels
channels
=
in_channels
emb_channels
=
temb_channels
emb_channels
=
temb_channels
use_scale_shift_norm
=
False
use_scale_shift_norm
=
False
...
@@ -466,8 +466,8 @@ class ResnetBlock(nn.Module):
...
@@ -466,8 +466,8 @@ class ResnetBlock(nn.Module):
)
)
if
self
.
out_channels
==
in_channels
:
if
self
.
out_channels
==
in_channels
:
self
.
skip_connection
=
nn
.
Identity
()
self
.
skip_connection
=
nn
.
Identity
()
# elif use_conv:
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else
:
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
...
...
src/diffusers/models/unet_ldm.py
View file @
26ce60c4
...
@@ -10,9 +10,10 @@ from ..configuration_utils import ConfigMixin
...
@@ -10,9 +10,10 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
TimestepBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock
,
TimestepBlock
,
Upsample
from
.resnet
import
ResnetBlock
#from .resnet import ResBlock
# from .resnet import ResBlock
def
exists
(
val
):
def
exists
(
val
):
...
@@ -601,16 +602,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -601,16 +602,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
out_ch
=
ch
out_ch
=
ch
self
.
input_blocks
.
append
(
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
TimestepEmbedSequential
(
# ResBlock(
# ResBlock(
# ch,
# ch,
# time_embed_dim,
# time_embed_dim,
# dropout,
# dropout,
# out_channels=out_ch,
# out_channels=out_ch,
# dims=dims,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# down=True,
# )
# )
None
None
if
resblock_updown
if
resblock_updown
else
Downsample
(
else
Downsample
(
...
@@ -703,16 +704,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -703,16 +704,16 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
if
level
and
i
==
num_res_blocks
:
if
level
and
i
==
num_res_blocks
:
out_ch
=
ch
out_ch
=
ch
layers
.
append
(
layers
.
append
(
# ResBlock(
# ResBlock(
# ch,
# ch,
# time_embed_dim,
# time_embed_dim,
# dropout,
# dropout,
# out_channels=out_ch,
# out_channels=out_ch,
# dims=dims,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# use_scale_shift_norm=use_scale_shift_norm,
# up=True,
# up=True,
# )
# )
None
None
if
resblock_updown
if
resblock_updown
else
Upsample
(
ch
,
use_conv
=
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
else
Upsample
(
ch
,
use_conv
=
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
...
@@ -876,16 +877,16 @@ class EncoderUNetModel(nn.Module):
...
@@ -876,16 +877,16 @@ class EncoderUNetModel(nn.Module):
out_ch
=
ch
out_ch
=
ch
self
.
input_blocks
.
append
(
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
TimestepEmbedSequential
(
# ResBlock(
# ResBlock(
# ch,
# ch,
# time_embed_dim,
# time_embed_dim,
# dropout,
# dropout,
# out_channels=out_ch,
# out_channels=out_ch,
# dims=dims,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# down=True,
# )
# )
None
None
if
resblock_updown
if
resblock_updown
else
Downsample
(
else
Downsample
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment