Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
7ff14b62
Commit
7ff14b62
authored
Jul 24, 2023
by
comfyanonymous
Browse files
ControlNetApplyAdvanced can now define when controlnet gets applied.
parent
d191c4f9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
61 deletions
+79
-61
comfy/samplers.py
comfy/samplers.py
+12
-0
comfy/sd.py
comfy/sd.py
+63
-59
nodes.py
nodes.py
+4
-2
No files found.
comfy/samplers.py
View file @
7ff14b62
...
...
@@ -455,6 +455,16 @@ def calculate_start_end_timesteps(model, conds):
n
[
'timestep_end'
]
=
timestep_end
conds
[
t
]
=
[
x
[
0
],
n
]
def
pre_run_control
(
model
,
conds
):
for
t
in
range
(
len
(
conds
)):
x
=
conds
[
t
]
timestep_start
=
None
timestep_end
=
None
percent_to_timestep_function
=
lambda
a
:
model
.
sigma_to_t
(
model
.
t_to_sigma
(
torch
.
tensor
(
a
)
*
999.0
))
if
'control'
in
x
[
1
]:
x
[
1
][
'control'
].
pre_run
(
model
.
inner_model
,
percent_to_timestep_function
)
def
apply_empty_x_to_equal_area
(
conds
,
uncond
,
name
,
uncond_fill_func
):
cond_cnets
=
[]
cond_other
=
[]
...
...
@@ -607,6 +617,8 @@ class KSampler:
for
c
in
negative
:
create_cond_with_same_area_if_none
(
positive
,
c
)
pre_run_control
(
self
.
model_wrap
,
negative
+
positive
)
apply_empty_x_to_equal_area
(
list
(
filter
(
lambda
c
:
c
[
1
].
get
(
'control_apply_to_uncond'
,
False
)
==
True
,
positive
)),
negative
,
'control'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
positive
,
negative
,
'gligen'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
...
...
comfy/sd.py
View file @
7ff14b62
...
...
@@ -673,16 +673,58 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else
:
return
torch
.
cat
([
tensor
]
*
batched_number
,
dim
=
0
)
class
ControlNet
:
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
self
.
control_model
=
control_model
class
ControlBase
:
def
__init__
(
self
,
device
=
None
):
self
.
cond_hint_original
=
None
self
.
cond_hint
=
None
self
.
strength
=
1.0
self
.
timestep_percent_range
=
(
1.0
,
0.0
)
self
.
timestep_range
=
None
if
device
is
None
:
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
previous_controlnet
=
None
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
1.0
,
0.0
)):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
self
.
timestep_percent_range
=
timestep_percent_range
return
self
def
pre_run
(
self
,
model
,
percent_to_timestep_function
):
self
.
timestep_range
=
(
percent_to_timestep_function
(
self
.
timestep_percent_range
[
0
]),
percent_to_timestep_function
(
self
.
timestep_percent_range
[
1
]))
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
pre_run
(
model
,
percent_to_timestep_function
)
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
self
.
timestep_range
=
None
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
out
.
append
(
self
.
control_model
)
return
out
def
copy_to
(
self
,
c
):
c
.
cond_hint_original
=
self
.
cond_hint_original
c
.
strength
=
self
.
strength
c
.
timestep_percent_range
=
self
.
timestep_percent_range
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
super
().
__init__
(
device
)
self
.
control_model
=
control_model
self
.
global_average_pooling
=
global_average_pooling
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
...
...
@@ -690,6 +732,13 @@ class ControlNet:
if
self
.
previous_controlnet
is
not
None
:
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
if
self
.
timestep_range
is
not
None
:
if
t
[
0
]
>
self
.
timestep_range
[
0
]
or
t
[
0
]
<
self
.
timestep_range
[
1
]:
if
control_prev
is
not
None
:
return
control_prev
else
:
return
{}
output_dtype
=
x_noisy
.
dtype
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
not
None
:
...
...
@@ -737,35 +786,11 @@ class ControlNet:
out
[
'input'
]
=
control_prev
[
'input'
]
return
out
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
return
self
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
def
copy
(
self
):
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
)
c
.
cond_hint_original
=
self
.
cond_hint_original
c
.
strength
=
self
.
strength
self
.
copy_to
(
c
)
return
c
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
out
.
append
(
self
.
control_model
)
return
out
def
load_controlnet
(
ckpt_path
,
model
=
None
):
controlnet_data
=
utils
.
load_torch_file
(
ckpt_path
,
safe_load
=
True
)
...
...
@@ -870,24 +895,25 @@ def load_controlnet(ckpt_path, model=None):
control
=
ControlNet
(
control_model
,
global_average_pooling
=
global_average_pooling
)
return
control
class
T2IAdapter
:
class
T2IAdapter
(
ControlBase
)
:
def
__init__
(
self
,
t2i_model
,
channels_in
,
device
=
None
):
super
().
__init__
(
device
)
self
.
t2i_model
=
t2i_model
self
.
channels_in
=
channels_in
self
.
strength
=
1.0
if
device
is
None
:
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
control_input
=
None
self
.
cond_hint_original
=
None
self
.
cond_hint
=
None
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
control_prev
=
None
if
self
.
previous_controlnet
is
not
None
:
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
if
self
.
timestep_range
is
not
None
:
if
t
[
0
]
>
self
.
timestep_range
[
0
]
or
t
[
0
]
<
self
.
timestep_range
[
1
]:
if
control_prev
is
not
None
:
return
control_prev
else
:
return
{}
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
...
...
@@ -932,33 +958,11 @@ class T2IAdapter:
out
[
'output'
]
=
control_prev
[
'output'
]
return
out
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
return
self
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
copy
(
self
):
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
c
.
cond_hint_original
=
self
.
cond_hint_original
c
.
strength
=
self
.
strength
self
.
copy_to
(
c
)
return
c
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
return
out
def
load_t2i_adapter
(
t2i_data
):
keys
=
t2i_data
.
keys
()
...
...
nodes.py
View file @
7ff14b62
...
...
@@ -615,6 +615,8 @@ class ControlNetApplyAdvanced:
"control_net"
:
(
"CONTROL_NET"
,
),
"image"
:
(
"IMAGE"
,
),
"strength"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"start"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"end"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
})
}}
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
)
...
...
@@ -623,7 +625,7 @@ class ControlNetApplyAdvanced:
CATEGORY
=
"conditioning"
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
):
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start
,
end
):
if
strength
==
0
:
return
(
positive
,
negative
)
...
...
@@ -640,7 +642,7 @@ class ControlNetApplyAdvanced:
if
prev_cnet
in
cnets
:
c_net
=
cnets
[
prev_cnet
]
else
:
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
)
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start
,
end
)
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
cnets
[
prev_cnet
]
=
c_net
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment