Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
62df8dd6
Commit
62df8dd6
authored
Feb 22, 2023
by
comfyanonymous
Browse files
Add a node to load diff controlnets.
parent
3ae61a2b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
1 deletion
+35
-1
comfy/sd.py
comfy/sd.py
+16
-1
nodes.py
nodes.py
+19
-0
No files found.
comfy/sd.py
View file @
62df8dd6
...
...
@@ -400,7 +400,7 @@ class ControlNet:
out
.
append
(
self
.
control_model
)
return
out
def
load_controlnet
(
ckpt_path
):
def
load_controlnet
(
ckpt_path
,
model
=
None
):
controlnet_data
=
load_torch_file
(
ckpt_path
)
pth_key
=
'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
pth
=
False
...
...
@@ -437,6 +437,21 @@ def load_controlnet(ckpt_path):
use_fp16
=
use_fp16
)
if
pth
:
if
'difference'
in
controlnet_data
:
if
model
is
not
None
:
m
=
model
.
patch_model
()
model_sd
=
m
.
state_dict
()
for
x
in
controlnet_data
:
c_m
=
"control_model."
if
x
.
startswith
(
c_m
):
sd_key
=
"model.diffusion_model.{}"
.
format
(
x
[
len
(
c_m
):])
if
sd_key
in
model_sd
:
cd
=
controlnet_data
[
x
]
cd
+=
model_sd
[
sd_key
].
type
(
cd
.
dtype
).
to
(
cd
.
device
)
model
.
unpatch_model
()
else
:
print
(
"WARNING: Loaded a diff controlnet without a model. It will very likely not work."
)
class
WeightsLoader
(
torch
.
nn
.
Module
):
pass
w
=
WeightsLoader
()
...
...
nodes.py
View file @
62df8dd6
...
...
@@ -232,6 +232,24 @@ class ControlNetLoader:
controlnet
=
comfy
.
sd
.
load_controlnet
(
controlnet_path
)
return
(
controlnet
,)
class
DiffControlNetLoader
:
models_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"models"
)
controlnet_dir
=
os
.
path
.
join
(
models_dir
,
"controlnet"
)
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"control_net_name"
:
(
filter_files_extensions
(
recursive_search
(
s
.
controlnet_dir
),
supported_pt_extensions
),
)}}
RETURN_TYPES
=
(
"CONTROL_NET"
,)
FUNCTION
=
"load_controlnet"
CATEGORY
=
"loaders"
def
load_controlnet
(
self
,
model
,
control_net_name
):
controlnet_path
=
os
.
path
.
join
(
self
.
controlnet_dir
,
control_net_name
)
controlnet
=
comfy
.
sd
.
load_controlnet
(
controlnet_path
,
model
)
return
(
controlnet
,)
class
ControlNetApply
:
@
classmethod
...
...
@@ -770,6 +788,7 @@ NODE_CLASS_MAPPINGS = {
"CLIPLoader"
:
CLIPLoader
,
"ControlNetApply"
:
ControlNetApply
,
"ControlNetLoader"
:
ControlNetLoader
,
"DiffControlNetLoader"
:
DiffControlNetLoader
,
}
CUSTOM_NODE_PATH
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"custom_nodes"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment