Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
8270c625
Commit
8270c625
authored
Jul 16, 2024
by
comfyanonymous
Browse files
Add SetUnionControlNetType to set the type of the union controlnet model.
parent
821f9387
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
1 deletion
+45
-1
comfy/controlnet.py
comfy/controlnet.py
+7
-1
comfy_extras/nodes_controlnet.py
comfy_extras/nodes_controlnet.py
+37
-0
nodes.py
nodes.py
+1
-0
No files found.
comfy/controlnet.py
View file @
8270c625
...
@@ -45,6 +45,7 @@ class ControlBase:
...
@@ -45,6 +45,7 @@ class ControlBase:
self
.
timestep_range
=
None
self
.
timestep_range
=
None
self
.
compression_ratio
=
8
self
.
compression_ratio
=
8
self
.
upscale_algorithm
=
'nearest-exact'
self
.
upscale_algorithm
=
'nearest-exact'
self
.
extra_args
=
{}
if
device
is
None
:
if
device
is
None
:
device
=
comfy
.
model_management
.
get_torch_device
()
device
=
comfy
.
model_management
.
get_torch_device
()
...
@@ -90,6 +91,7 @@ class ControlBase:
...
@@ -90,6 +91,7 @@ class ControlBase:
c
.
compression_ratio
=
self
.
compression_ratio
c
.
compression_ratio
=
self
.
compression_ratio
c
.
upscale_algorithm
=
self
.
upscale_algorithm
c
.
upscale_algorithm
=
self
.
upscale_algorithm
c
.
latent_format
=
self
.
latent_format
c
.
latent_format
=
self
.
latent_format
c
.
extra_args
=
self
.
extra_args
.
copy
()
c
.
vae
=
self
.
vae
c
.
vae
=
self
.
vae
def
inference_memory_requirements
(
self
,
dtype
):
def
inference_memory_requirements
(
self
,
dtype
):
...
@@ -135,6 +137,10 @@ class ControlBase:
...
@@ -135,6 +137,10 @@ class ControlBase:
o
[
i
]
=
prev_val
+
o
[
i
]
#TODO: change back to inplace add if shared tensors stop being an issue
o
[
i
]
=
prev_val
+
o
[
i
]
#TODO: change back to inplace add if shared tensors stop being an issue
return
out
return
out
def
set_extra_arg
(
self
,
argument
,
value
=
None
):
self
.
extra_args
[
argument
]
=
value
class
ControlNet
(
ControlBase
):
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
=
None
,
global_average_pooling
=
False
,
compression_ratio
=
8
,
latent_format
=
None
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
def
__init__
(
self
,
control_model
=
None
,
global_average_pooling
=
False
,
compression_ratio
=
8
,
latent_format
=
None
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
super
().
__init__
(
device
)
super
().
__init__
(
device
)
...
@@ -191,7 +197,7 @@ class ControlNet(ControlBase):
...
@@ -191,7 +197,7 @@ class ControlNet(ControlBase):
timestep
=
self
.
model_sampling_current
.
timestep
(
t
)
timestep
=
self
.
model_sampling_current
.
timestep
(
t
)
x_noisy
=
self
.
model_sampling_current
.
calculate_input
(
t
,
x_noisy
)
x_noisy
=
self
.
model_sampling_current
.
calculate_input
(
t
,
x_noisy
)
control
=
self
.
control_model
(
x
=
x_noisy
.
to
(
dtype
),
hint
=
self
.
cond_hint
,
timesteps
=
timestep
.
float
(),
context
=
context
.
to
(
dtype
),
y
=
y
)
control
=
self
.
control_model
(
x
=
x_noisy
.
to
(
dtype
),
hint
=
self
.
cond_hint
,
timesteps
=
timestep
.
float
(),
context
=
context
.
to
(
dtype
),
y
=
y
,
**
self
.
extra_args
)
return
self
.
control_merge
(
control
,
control_prev
,
output_dtype
)
return
self
.
control_merge
(
control
,
control_prev
,
output_dtype
)
def
copy
(
self
):
def
copy
(
self
):
...
...
comfy_extras/nodes_controlnet.py
0 → 100644
View file @
8270c625
UNION_CONTROLNET_TYPES
=
{
"auto"
:
-
1
,
"openpose"
:
0
,
"depth"
:
1
,
"hed/pidi/scribble/ted"
:
2
,
"canny/lineart/anime_lineart/mlsd"
:
3
,
"normal"
:
4
,
"segment"
:
5
,
"tile"
:
6
,
"repaint"
:
7
,
}
class
SetUnionControlNetType
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"control_net"
:
(
"CONTROL_NET"
,
),
"type"
:
(
list
(
UNION_CONTROLNET_TYPES
.
keys
()),)
}}
CATEGORY
=
"conditioning"
RETURN_TYPES
=
(
"CONTROL_NET"
,)
FUNCTION
=
"set_controlnet_type"
def
set_controlnet_type
(
self
,
control_net
,
type
):
control_net
=
control_net
.
copy
()
type_number
=
UNION_CONTROLNET_TYPES
[
type
]
if
type_number
>=
0
:
control_net
.
set_extra_arg
(
"control_type"
,
[
type_number
])
else
:
control_net
.
set_extra_arg
(
"control_type"
,
[])
return
(
control_net
,)
NODE_CLASS_MAPPINGS
=
{
"SetUnionControlNetType"
:
SetUnionControlNetType
,
}
nodes.py
View file @
8270c625
...
@@ -2036,6 +2036,7 @@ def init_builtin_extra_nodes():
...
@@ -2036,6 +2036,7 @@ def init_builtin_extra_nodes():
"nodes_audio.py"
,
"nodes_audio.py"
,
"nodes_sd3.py"
,
"nodes_sd3.py"
,
"nodes_gits.py"
,
"nodes_gits.py"
,
"nodes_controlnet.py"
,
]
]
import_failed
=
[]
import_failed
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment