Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
220a72d3
Commit
220a72d3
authored
Feb 17, 2023
by
comfyanonymous
Browse files
Use fp16 for fp16 control nets.
parent
71354c7c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
6 deletions
+26
-6
comfy/sd.py
comfy/sd.py
+26
-6
No files found.
comfy/sd.py
View file @
220a72d3
import
torch
import
contextlib
import
sd1_clip
import
sd2_clip
...
...
@@ -327,23 +328,36 @@ class VAE:
return
samples
class
ControlNet
:
def
__init__
(
self
,
control_model
):
def
__init__
(
self
,
control_model
,
device
=
"cuda"
):
self
.
control_model
=
control_model
self
.
cond_hint_original
=
None
self
.
cond_hint
=
None
self
.
strength
=
1.0
self
.
device
=
device
def
get_control
(
self
,
x_noisy
,
t
,
cond_txt
):
output_dtype
=
x_noisy
.
dtype
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
self
.
cond_hint
=
utils
.
common_upscale
(
self
.
cond_hint_original
,
x_noisy
.
shape
[
3
]
*
8
,
x_noisy
.
shape
[
2
]
*
8
,
'nearest-exact'
,
"center"
).
to
(
x_noisy
.
device
)
print
(
"set cond_hint"
,
self
.
cond_hint
.
shape
)
control
=
self
.
control_model
(
x
=
x_noisy
,
hint
=
self
.
cond_hint
,
timesteps
=
t
,
context
=
cond_txt
)
self
.
cond_hint
=
utils
.
common_upscale
(
self
.
cond_hint_original
,
x_noisy
.
shape
[
3
]
*
8
,
x_noisy
.
shape
[
2
]
*
8
,
'nearest-exact'
,
"center"
).
to
(
self
.
control_model
.
dtype
).
to
(
self
.
device
)
if
self
.
control_model
.
dtype
==
torch
.
float16
:
precision_scope
=
torch
.
autocast
else
:
precision_scope
=
contextlib
.
nullcontext
with
precision_scope
(
self
.
device
):
control
=
self
.
control_model
(
x
=
x_noisy
,
hint
=
self
.
cond_hint
,
timesteps
=
t
,
context
=
cond_txt
)
out
=
[]
autocast_enabled
=
torch
.
is_autocast_enabled
()
for
x
in
control
:
x
*=
self
.
strength
return
control
if
x
.
dtype
!=
output_dtype
and
not
autocast_enabled
:
x
=
x
.
to
(
output_dtype
)
out
.
append
(
x
)
return
out
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
):
self
.
cond_hint_original
=
cond_hint
...
...
@@ -377,6 +391,11 @@ def load_controlnet(ckpt_path):
return
None
context_dim
=
controlnet_data
[
key
].
shape
[
1
]
use_fp16
=
False
if
controlnet_data
[
key
].
dtype
==
torch
.
float16
:
use_fp16
=
True
control_model
=
cldm
.
ControlNet
(
image_size
=
32
,
in_channels
=
4
,
hint_channels
=
3
,
...
...
@@ -389,7 +408,8 @@ def load_controlnet(ckpt_path):
transformer_depth
=
1
,
context_dim
=
context_dim
,
use_checkpoint
=
True
,
legacy
=
False
)
legacy
=
False
,
use_fp16
=
use_fp16
)
if
pth
:
class
WeightsLoader
(
torch
.
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment