Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
7ff14b62
"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "82b0ebd48fe5ba2185bba38054fcfd2cb827d33f"
Commit
7ff14b62
authored
Jul 24, 2023
by
comfyanonymous
Browse files
ControlNetApplyAdvanced can now define when controlnet gets applied.
parent
d191c4f9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
61 deletions
+79
-61
comfy/samplers.py
comfy/samplers.py
+12
-0
comfy/sd.py
comfy/sd.py
+63
-59
nodes.py
nodes.py
+4
-2
No files found.
comfy/samplers.py
View file @
7ff14b62
...
@@ -455,6 +455,16 @@ def calculate_start_end_timesteps(model, conds):
...
@@ -455,6 +455,16 @@ def calculate_start_end_timesteps(model, conds):
n
[
'timestep_end'
]
=
timestep_end
n
[
'timestep_end'
]
=
timestep_end
conds
[
t
]
=
[
x
[
0
],
n
]
conds
[
t
]
=
[
x
[
0
],
n
]
def
pre_run_control
(
model
,
conds
):
for
t
in
range
(
len
(
conds
)):
x
=
conds
[
t
]
timestep_start
=
None
timestep_end
=
None
percent_to_timestep_function
=
lambda
a
:
model
.
sigma_to_t
(
model
.
t_to_sigma
(
torch
.
tensor
(
a
)
*
999.0
))
if
'control'
in
x
[
1
]:
x
[
1
][
'control'
].
pre_run
(
model
.
inner_model
,
percent_to_timestep_function
)
def
apply_empty_x_to_equal_area
(
conds
,
uncond
,
name
,
uncond_fill_func
):
def
apply_empty_x_to_equal_area
(
conds
,
uncond
,
name
,
uncond_fill_func
):
cond_cnets
=
[]
cond_cnets
=
[]
cond_other
=
[]
cond_other
=
[]
...
@@ -607,6 +617,8 @@ class KSampler:
...
@@ -607,6 +617,8 @@ class KSampler:
for
c
in
negative
:
for
c
in
negative
:
create_cond_with_same_area_if_none
(
positive
,
c
)
create_cond_with_same_area_if_none
(
positive
,
c
)
pre_run_control
(
self
.
model_wrap
,
negative
+
positive
)
apply_empty_x_to_equal_area
(
list
(
filter
(
lambda
c
:
c
[
1
].
get
(
'control_apply_to_uncond'
,
False
)
==
True
,
positive
)),
negative
,
'control'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
list
(
filter
(
lambda
c
:
c
[
1
].
get
(
'control_apply_to_uncond'
,
False
)
==
True
,
positive
)),
negative
,
'control'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
positive
,
negative
,
'gligen'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
apply_empty_x_to_equal_area
(
positive
,
negative
,
'gligen'
,
lambda
cond_cnets
,
x
:
cond_cnets
[
x
])
...
...
comfy/sd.py
View file @
7ff14b62
...
@@ -673,16 +673,58 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
...
@@ -673,16 +673,58 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else
:
else
:
return
torch
.
cat
([
tensor
]
*
batched_number
,
dim
=
0
)
return
torch
.
cat
([
tensor
]
*
batched_number
,
dim
=
0
)
class
ControlNet
:
class
ControlBase
:
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
def
__init__
(
self
,
device
=
None
):
self
.
control_model
=
control_model
self
.
cond_hint_original
=
None
self
.
cond_hint_original
=
None
self
.
cond_hint
=
None
self
.
cond_hint
=
None
self
.
strength
=
1.0
self
.
strength
=
1.0
self
.
timestep_percent_range
=
(
1.0
,
0.0
)
self
.
timestep_range
=
None
if
device
is
None
:
if
device
is
None
:
device
=
model_management
.
get_torch_device
()
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
previous_controlnet
=
None
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
1.0
,
0.0
)):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
self
.
timestep_percent_range
=
timestep_percent_range
return
self
def
pre_run
(
self
,
model
,
percent_to_timestep_function
):
self
.
timestep_range
=
(
percent_to_timestep_function
(
self
.
timestep_percent_range
[
0
]),
percent_to_timestep_function
(
self
.
timestep_percent_range
[
1
]))
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
pre_run
(
model
,
percent_to_timestep_function
)
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
self
.
timestep_range
=
None
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
out
.
append
(
self
.
control_model
)
return
out
def
copy_to
(
self
,
c
):
c
.
cond_hint_original
=
self
.
cond_hint_original
c
.
strength
=
self
.
strength
c
.
timestep_percent_range
=
self
.
timestep_percent_range
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
super
().
__init__
(
device
)
self
.
control_model
=
control_model
self
.
global_average_pooling
=
global_average_pooling
self
.
global_average_pooling
=
global_average_pooling
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
...
@@ -690,6 +732,13 @@ class ControlNet:
...
@@ -690,6 +732,13 @@ class ControlNet:
if
self
.
previous_controlnet
is
not
None
:
if
self
.
previous_controlnet
is
not
None
:
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
if
self
.
timestep_range
is
not
None
:
if
t
[
0
]
>
self
.
timestep_range
[
0
]
or
t
[
0
]
<
self
.
timestep_range
[
1
]:
if
control_prev
is
not
None
:
return
control_prev
else
:
return
{}
output_dtype
=
x_noisy
.
dtype
output_dtype
=
x_noisy
.
dtype
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
not
None
:
if
self
.
cond_hint
is
not
None
:
...
@@ -737,35 +786,11 @@ class ControlNet:
...
@@ -737,35 +786,11 @@ class ControlNet:
out
[
'input'
]
=
control_prev
[
'input'
]
out
[
'input'
]
=
control_prev
[
'input'
]
return
out
return
out
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
return
self
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
def
copy
(
self
):
def
copy
(
self
):
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
)
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
)
c
.
cond_hint_original
=
self
.
cond_hint_original
self
.
copy_to
(
c
)
c
.
strength
=
self
.
strength
return
c
return
c
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
out
.
append
(
self
.
control_model
)
return
out
def
load_controlnet
(
ckpt_path
,
model
=
None
):
def
load_controlnet
(
ckpt_path
,
model
=
None
):
controlnet_data
=
utils
.
load_torch_file
(
ckpt_path
,
safe_load
=
True
)
controlnet_data
=
utils
.
load_torch_file
(
ckpt_path
,
safe_load
=
True
)
...
@@ -870,24 +895,25 @@ def load_controlnet(ckpt_path, model=None):
...
@@ -870,24 +895,25 @@ def load_controlnet(ckpt_path, model=None):
control
=
ControlNet
(
control_model
,
global_average_pooling
=
global_average_pooling
)
control
=
ControlNet
(
control_model
,
global_average_pooling
=
global_average_pooling
)
return
control
return
control
class
T2IAdapter
:
class
T2IAdapter
(
ControlBase
)
:
def
__init__
(
self
,
t2i_model
,
channels_in
,
device
=
None
):
def
__init__
(
self
,
t2i_model
,
channels_in
,
device
=
None
):
super
().
__init__
(
device
)
self
.
t2i_model
=
t2i_model
self
.
t2i_model
=
t2i_model
self
.
channels_in
=
channels_in
self
.
channels_in
=
channels_in
self
.
strength
=
1.0
if
device
is
None
:
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
control_input
=
None
self
.
control_input
=
None
self
.
cond_hint_original
=
None
self
.
cond_hint
=
None
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
control_prev
=
None
control_prev
=
None
if
self
.
previous_controlnet
is
not
None
:
if
self
.
previous_controlnet
is
not
None
:
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
control_prev
=
self
.
previous_controlnet
.
get_control
(
x_noisy
,
t
,
cond
,
batched_number
)
if
self
.
timestep_range
is
not
None
:
if
t
[
0
]
>
self
.
timestep_range
[
0
]
or
t
[
0
]
<
self
.
timestep_range
[
1
]:
if
control_prev
is
not
None
:
return
control_prev
else
:
return
{}
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
if
self
.
cond_hint
is
not
None
:
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
del
self
.
cond_hint
...
@@ -932,33 +958,11 @@ class T2IAdapter:
...
@@ -932,33 +958,11 @@ class T2IAdapter:
out
[
'output'
]
=
control_prev
[
'output'
]
out
[
'output'
]
=
control_prev
[
'output'
]
return
out
return
out
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
return
self
def
set_previous_controlnet
(
self
,
controlnet
):
self
.
previous_controlnet
=
controlnet
return
self
def
copy
(
self
):
def
copy
(
self
):
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
c
.
cond_hint_original
=
self
.
cond_hint_original
self
.
copy_to
(
c
)
c
.
strength
=
self
.
strength
return
c
return
c
def
cleanup
(
self
):
if
self
.
previous_controlnet
is
not
None
:
self
.
previous_controlnet
.
cleanup
()
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
def
get_models
(
self
):
out
=
[]
if
self
.
previous_controlnet
is
not
None
:
out
+=
self
.
previous_controlnet
.
get_models
()
return
out
def
load_t2i_adapter
(
t2i_data
):
def
load_t2i_adapter
(
t2i_data
):
keys
=
t2i_data
.
keys
()
keys
=
t2i_data
.
keys
()
...
...
nodes.py
View file @
7ff14b62
...
@@ -615,6 +615,8 @@ class ControlNetApplyAdvanced:
...
@@ -615,6 +615,8 @@ class ControlNetApplyAdvanced:
"control_net"
:
(
"CONTROL_NET"
,
),
"control_net"
:
(
"CONTROL_NET"
,
),
"image"
:
(
"IMAGE"
,
),
"image"
:
(
"IMAGE"
,
),
"strength"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"strength"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"start"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"end"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
})
}}
}}
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
)
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
)
...
@@ -623,7 +625,7 @@ class ControlNetApplyAdvanced:
...
@@ -623,7 +625,7 @@ class ControlNetApplyAdvanced:
CATEGORY
=
"conditioning"
CATEGORY
=
"conditioning"
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
):
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start
,
end
):
if
strength
==
0
:
if
strength
==
0
:
return
(
positive
,
negative
)
return
(
positive
,
negative
)
...
@@ -640,7 +642,7 @@ class ControlNetApplyAdvanced:
...
@@ -640,7 +642,7 @@ class ControlNetApplyAdvanced:
if
prev_cnet
in
cnets
:
if
prev_cnet
in
cnets
:
c_net
=
cnets
[
prev_cnet
]
c_net
=
cnets
[
prev_cnet
]
else
:
else
:
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
)
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start
,
end
)
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
cnets
[
prev_cnet
]
=
c_net
cnets
[
prev_cnet
]
=
c_net
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment