Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
7ff14b62
Commit
7ff14b62
authored
Jul 24, 2023
by
comfyanonymous
Browse files
ControlNetApplyAdvanced can now define when controlnet gets applied.
parent
d191c4f9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
61 deletions
+79
-61
comfy/samplers.py
comfy/samplers.py
+12
-0
comfy/sd.py
comfy/sd.py
+63
-59
nodes.py
nodes.py
+4
-2
No files found.
comfy/samplers.py
View file @
7ff14b62
...
@@ -455,6 +455,16 @@ def calculate_start_end_timesteps(model, conds):
...
@@ -455,6 +455,16 @@ def calculate_start_end_timesteps(model, conds):
n
[
'timestep_end'
]
=
timestep_end
n
[
'timestep_end'
]
=
timestep_end
conds
[
t
]
=
[
x
[
0
],
n
]
conds
[
t
]
=
[
x
[
0
],
n
]
def
pre_run_control
(
model
,
conds
):
for
t
in
range
(
len
(
conds
)):
x
=
conds
[
t
]
timestep_start
=
None
timestep_end
=
None
percent_to_timestep_function
=
lambda
a
:
model
.
sigma_to_t
(
model
.
t_to_sigma
(
torch
.
tensor
(
a
)
*
999.0
))
if
'control'
in
x
[
1
]:
x
[
1
][
'control'
].
pre_run
(
model
.
inner_model
,
percent_to_timestep_function
)
def
apply_empty_x_to_equal_area
(
conds
,
uncond
,
name
,
uncond_fill_func
):
def
apply_empty_x_to_equal_area
(
conds
,
uncond
,
name
,
uncond_fill_func
):
cond_cnets
=
[]
cond_cnets
=
[]
cond_other
=
[]
cond_other
=
[]
...
@@ -607,6 +617,8 @@ class KSampler:
...
@@ -607,6 +617,8 @@ class KSampler:
for
c
in
negative
:
for
c
in
negative
:
create_cond_with_same_area_if_none
(
positive
,
c
)
create_cond_with_same_area_if_none
(
positive
,
c
)
pre_run_control
(
self
.
model_wrap
,
negative
+
positive
)
apply_empty_x_to_equal_area
(
list
(
filter
(
lambda
c
:
c
[
1
].
get
(
'control_apply_to_uncond'
,
False
)
==
True
,
positive
)),
negative
,
'control'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
list
(
filter
(
lambda
c
:
c
[
1
].
get
(
'control_apply_to_uncond'
,
False
)
==
True
,
positive
)),
negative
,
'control'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
positive
,
negative
,
'gligen'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
positive
,
negative
,
'gligen'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
...
...
comfy/sd.py
View file @
7ff14b62
...
@@ -673,16 +673,58 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
...
@@ -673,16 +673,58 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else
:
else
:
return
torch
.
cat
([
tensor
]
*
batched_number
,
dim
=
0
)
return
torch
.
cat
([
tensor
]
*
batched_number
,
dim
=
0
)
class
ControlNet
:
class
ControlBase
:
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
def
__init__
(
self
,
device
=
None
):
self
.
control_model
=
control_model
self
.
cond_hint_original
=
None
self
.
cond_hint_original
=
None
self
.
cond_hint
=
None
self
.
cond_hint
=
None
self
.
strength
=
1.0
self
.
strength
=
1.0
self
.
timestep_percent_range
=
(
1.0
,
0.0
)
self
.
timestep_range
=
None
if
device
is
None
:
if
device
is
None
:
device
=
model_management
.
get_torch_device
()
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
previous_controlnet
=
None
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
1.0
,
0.0
)):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
self
.
timestep_percent_range
=
timestep_percent_range
return
self
def
pre_run
(
self
,
model
,
percent_to_timestep_function
):
self
.
timestep_range
=
(
percent_to_timestep_function
(
self
.
timestep_percent_range
[
0
]),
percent_to_timestep_function
(
self
.
timestep_percent_range
[
1
]))
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
pre_run
(
model
,
percent_to_timestep_function
)
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
self
.
timestep_range
=
None
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
out
.
append
(
self
.
control_model
)
return
out
def
copy_to
(
self
,
c
):
c
.
cond_hint_original
=
self
.
cond_hint_original
c
.
strength
=
self
.
strength
c
.
timestep_percent_range
=
self
.
timestep_percent_range
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
super
().
__init__
(
device
)
self
.
control_model
=
control_model
self
.
global_average_pooling
=
global_average_pooling
self
.
global_average_pooling
=
global_average_pooling
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
...
@@ -690,6 +732,13 @@ class ControlNet:
...
@@ -690,6 +732,13 @@ class ControlNet:
if
self
.
previous_controlnet
is
not
None
:
if
self
.
previous_controlnet
is
not
None
:
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
if
self
.
timestep_range
is
not
None
:
if
t
[
0
]
>
self
.
timestep_range
[
0
]
or
t
[
0
]
<
self
.
timestep_range
[
1
]:
if
control_prev
is
not
None
:
return
control_prev
else
:
return
{}
output_dtype
=
x_noisy
.
dtype
output_dtype
=
x_noisy
.
dtype
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
not
None
:
if
self
.
cond_hint
is
not
None
:
...
@@ -737,35 +786,11 @@ class ControlNet:
...
@@ -737,35 +786,11 @@ class ControlNet:
out
[
'input'
]
=
control_prev
[
'input'
]
out
[
'input'
]
=
control_prev
[
'input'
]
return
out
return
out
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
return
self
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
def
copy
(
self
):
def
copy
(
self
):
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
)
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
)
c
.
cond_hint_original
=
self
.
cond_hint_original
self
.
copy_to
(
c
)
c
.
strength
=
self
.
strength
return
c
return
c
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
out
.
append
(
self
.
control_model
)
return
out
def
load_controlnet
(
ckpt_path
,
model
=
None
):
def
load_controlnet
(
ckpt_path
,
model
=
None
):
controlnet_data
=
utils
.
load_torch_file
(
ckpt_path
,
safe_load
=
True
)
controlnet_data
=
utils
.
load_torch_file
(
ckpt_path
,
safe_load
=
True
)
...
@@ -870,24 +895,25 @@ def load_controlnet(ckpt_path, model=None):
...
@@ -870,24 +895,25 @@ def load_controlnet(ckpt_path, model=None):
control
=
ControlNet
(
control_model
,
global_average_pooling
=
global_average_pooling
)
control
=
ControlNet
(
control_model
,
global_average_pooling
=
global_average_pooling
)
return
control
return
control
class
T2IAdapter
:
class
T2IAdapter
(
ControlBase
)
:
def
__init__
(
self
,
t2i_model
,
channels_in
,
device
=
None
):
def
__init__
(
self
,
t2i_model
,
channels_in
,
device
=
None
):
super
().
__init__
(
device
)
self
.
t2i_model
=
t2i_model
self
.
t2i_model
=
t2i_model
self
.
channels_in
=
channels_in
self
.
channels_in
=
channels_in
self
.
strength
=
1.0
if
device
is
None
:
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
control_input
=
None
self
.
control_input
=
None
self
.
cond_hint_original
=
None
self
.
cond_hint
=
None
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
control_prev
=
None
control_prev
=
None
if
self
.
previous_controlnet
is
not
None
:
if
self
.
previous_controlnet
is
not
None
:
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
if
self
.
timestep_range
is
not
None
:
if
t
[
0
]
>
self
.
timestep_range
[
0
]
or
t
[
0
]
<
self
.
timestep_range
[
1
]:
if
control_prev
is
not
None
:
return
control_prev
else
:
return
{}
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
not
None
:
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
del
self
.
cond_hint
...
@@ -932,33 +958,11 @@ class T2IAdapter:
...
@@ -932,33 +958,11 @@ class T2IAdapter:
out
[
'output'
]
=
control_prev
[
'output'
]
out
[
'output'
]
=
control_prev
[
'output'
]
return
out
return
out
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
return
self
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
copy
(
self
):
def
copy
(
self
):
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
c
.
cond_hint_original
=
self
.
cond_hint_original
self
.
copy_to
(
c
)
c
.
strength
=
self
.
strength
return
c
return
c
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
return
out
def
load_t2i_adapter
(
t2i_data
):
def
load_t2i_adapter
(
t2i_data
):
keys
=
t2i_data
.
keys
()
keys
=
t2i_data
.
keys
()
...
...
nodes.py
View file @
7ff14b62
...
@@ -615,6 +615,8 @@ class ControlNetApplyAdvanced:
...
@@ -615,6 +615,8 @@ class ControlNetApplyAdvanced:
"control_net"
:
(
"CONTROL_NET"
,
),
"control_net"
:
(
"CONTROL_NET"
,
),
"image"
:
(
"IMAGE"
,
),
"image"
:
(
"IMAGE"
,
),
"strength"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"strength"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"start"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"end"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
})
}}
}}
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
)
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
)
...
@@ -623,7 +625,7 @@ class ControlNetApplyAdvanced:
...
@@ -623,7 +625,7 @@ class ControlNetApplyAdvanced:
CATEGORY
=
"conditioning"
CATEGORY
=
"conditioning"
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
):
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start
,
end
):
if
strength
==
0
:
if
strength
==
0
:
return
(
positive
,
negative
)
return
(
positive
,
negative
)
...
@@ -640,7 +642,7 @@ class ControlNetApplyAdvanced:
...
@@ -640,7 +642,7 @@ class ControlNetApplyAdvanced:
if
prev_cnet
in
cnets
:
if
prev_cnet
in
cnets
:
c_net
=
cnets
[
prev_cnet
]
c_net
=
cnets
[
prev_cnet
]
else
:
else
:
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
)
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start
,
end
)
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
cnets
[
prev_cnet
]
=
c_net
cnets
[
prev_cnet
]
=
c_net
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment