Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
7bf89ba9
"ppocr/vscode:/vscode.git/clone" did not exist on "e4f998b997b48bb6bfa447e1fedfb8ef4dbb7d77"
Commit
7bf89ba9
authored
Jun 15, 2023
by
comfyanonymous
Browse files
Initialize more unet weights as the right dtype.
parent
e21d9ad4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
9 deletions
+10
-9
comfy/ldm/modules/diffusionmodules/openaimodel.py
comfy/ldm/modules/diffusionmodules/openaimodel.py
+10
-9
No files found.
comfy/ldm/modules/diffusionmodules/openaimodel.py
View file @
7bf89ba9
...
...
@@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
use_checkpoint
=
False
,
up
=
False
,
down
=
False
,
dtype
=
None
):
super
().
__init__
()
self
.
channels
=
channels
...
...
@@ -221,7 +222,7 @@ class ResBlock(TimestepBlock):
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
),
nn
.
SiLU
(),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
,
dtype
=
dtype
),
)
self
.
updown
=
up
or
down
...
...
@@ -247,7 +248,7 @@ class ResBlock(TimestepBlock):
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
,
dtype
=
dtype
)
),
)
...
...
@@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
self
.
skip_connection
=
nn
.
Identity
()
elif
use_conv
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
,
dtype
=
dtype
)
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
,
dtype
=
dtype
)
def
forward
(
self
,
x
,
emb
):
"""
...
...
@@ -558,9 +559,9 @@ class UNetModel(nn.Module):
time_embed_dim
=
model_channels
*
4
self
.
time_embed
=
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
linear
(
model_channels
,
time_embed_dim
,
dtype
=
self
.
dtype
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
linear
(
time_embed_dim
,
time_embed_dim
,
dtype
=
self
.
dtype
),
)
if
self
.
num_classes
is
not
None
:
...
...
@@ -573,9 +574,9 @@ class UNetModel(nn.Module):
assert
adm_in_channels
is
not
None
self
.
label_emb
=
nn
.
Sequential
(
nn
.
Sequential
(
linear
(
adm_in_channels
,
time_embed_dim
),
linear
(
adm_in_channels
,
time_embed_dim
,
dtype
=
self
.
dtype
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
linear
(
time_embed_dim
,
time_embed_dim
,
dtype
=
self
.
dtype
),
)
)
else
:
...
...
@@ -584,7 +585,7 @@ class UNetModel(nn.Module):
self
.
input_blocks
=
nn
.
ModuleList
(
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
)
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
,
dtype
=
self
.
dtype
)
)
]
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment