Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
d6dee8af
Commit
d6dee8af
authored
May 10, 2023
by
comfyanonymous
Browse files
Only validate each input once.
parent
02ca1c67
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
24 deletions
+20
-24
execution.py
execution.py
+18
-22
main.py
main.py
+1
-1
server.py
server.py
+1
-1
No files found.
execution.py
View file @
d6dee8af
...
...
@@ -147,7 +147,7 @@ class PromptExecutor:
self
.
old_prompt
=
{}
self
.
server
=
server
def
execute
(
self
,
prompt
,
extra_data
=
{}):
def
execute
(
self
,
prompt
,
extra_data
=
{}
,
execute_outputs
=
[]
):
nodes
.
interrupt_processing
(
False
)
if
"client_id"
in
extra_data
:
...
...
@@ -172,27 +172,15 @@ class PromptExecutor:
executed
=
set
()
try
:
to_execute
=
[]
for
x
in
prompt
:
class_
=
nodes
.
NODE_CLASS_MAPPINGS
[
prompt
[
x
][
'class_type'
]]
if
hasattr
(
class_
,
'OUTPUT_NODE'
):
to_execute
+=
[(
0
,
x
)]
for
x
in
list
(
execute_outputs
):
to_execute
+=
[(
0
,
x
)]
while
len
(
to_execute
)
>
0
:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute
=
sorted
(
list
(
map
(
lambda
a
:
(
len
(
recursive_will_execute
(
prompt
,
self
.
outputs
,
a
[
-
1
])),
a
[
-
1
]),
to_execute
)))
x
=
to_execute
.
pop
(
0
)[
-
1
]
class_
=
nodes
.
NODE_CLASS_MAPPINGS
[
prompt
[
x
][
'class_type'
]]
if
hasattr
(
class_
,
'OUTPUT_NODE'
):
if
class_
.
OUTPUT_NODE
==
True
:
valid
=
False
try
:
m
=
validate_inputs
(
prompt
,
x
)
valid
=
m
[
0
]
except
:
valid
=
False
if
valid
:
recursive_execute
(
self
.
server
,
prompt
,
self
.
outputs
,
x
,
extra_data
,
executed
)
recursive_execute
(
self
.
server
,
prompt
,
self
.
outputs
,
x
,
extra_data
,
executed
)
except
Exception
as
e
:
if
isinstance
(
e
,
comfy
.
model_management
.
InterruptProcessingException
):
print
(
"Processing interrupted"
)
...
...
@@ -219,8 +207,11 @@ class PromptExecutor:
comfy
.
model_management
.
soft_empty_cache
()
def
validate_inputs
(
prompt
,
item
):
def
validate_inputs
(
prompt
,
item
,
validated
):
unique_id
=
item
if
unique_id
in
validated
:
return
validated
[
unique_id
]
inputs
=
prompt
[
unique_id
][
'inputs'
]
class_type
=
prompt
[
unique_id
][
'class_type'
]
obj_class
=
nodes
.
NODE_CLASS_MAPPINGS
[
class_type
]
...
...
@@ -241,8 +232,9 @@ def validate_inputs(prompt, item):
r
=
nodes
.
NODE_CLASS_MAPPINGS
[
o_class_type
].
RETURN_TYPES
if
r
[
val
[
1
]]
!=
type_input
:
return
(
False
,
"Return type mismatch. {}, {}, {} != {}"
.
format
(
class_type
,
x
,
r
[
val
[
1
]],
type_input
))
r
=
validate_inputs
(
prompt
,
o_id
)
r
=
validate_inputs
(
prompt
,
o_id
,
validated
)
if
r
[
0
]
==
False
:
validated
[
o_id
]
=
r
return
r
else
:
if
type_input
==
"INT"
:
...
...
@@ -270,7 +262,10 @@ def validate_inputs(prompt, item):
if
isinstance
(
type_input
,
list
):
if
val
not
in
type_input
:
return
(
False
,
"Value not in list. {}, {}: {} not in {}"
.
format
(
class_type
,
x
,
val
,
type_input
))
return
(
True
,
""
)
ret
=
(
True
,
""
)
validated
[
unique_id
]
=
ret
return
ret
def
validate_prompt
(
prompt
):
outputs
=
set
()
...
...
@@ -284,11 +279,12 @@ def validate_prompt(prompt):
good_outputs
=
set
()
errors
=
[]
validated
=
{}
for
o
in
outputs
:
valid
=
False
reason
=
""
try
:
m
=
validate_inputs
(
prompt
,
o
)
m
=
validate_inputs
(
prompt
,
o
,
validated
)
valid
=
m
[
0
]
reason
=
m
[
1
]
except
Exception
as
e
:
...
...
@@ -297,7 +293,7 @@ def validate_prompt(prompt):
reason
=
"Parsing error"
if
valid
==
True
:
good_outputs
.
add
(
x
)
good_outputs
.
add
(
o
)
else
:
print
(
"Failed to validate prompt for output {} {}"
.
format
(
o
,
reason
))
print
(
"output will be ignored"
)
...
...
@@ -307,7 +303,7 @@ def validate_prompt(prompt):
errors_list
=
"
\n
"
.
join
(
set
(
map
(
lambda
a
:
"{}"
.
format
(
a
[
1
]),
errors
)))
return
(
False
,
"Prompt has no properly connected outputs
\n
{}"
.
format
(
errors_list
))
return
(
True
,
""
)
return
(
True
,
""
,
list
(
good_outputs
)
)
class
PromptQueue
:
...
...
main.py
View file @
d6dee8af
...
...
@@ -33,7 +33,7 @@ def prompt_worker(q, server):
e
=
execution
.
PromptExecutor
(
server
)
while
True
:
item
,
item_id
=
q
.
get
()
e
.
execute
(
item
[
-
2
],
item
[
-
1
])
e
.
execute
(
item
[
-
3
],
item
[
-
2
],
item
[
-
1
])
q
.
task_done
(
item_id
,
e
.
outputs
)
async
def
run
(
server
,
address
=
''
,
port
=
8188
,
verbose
=
True
,
call_on_start
=
None
):
...
...
server.py
View file @
d6dee8af
...
...
@@ -312,7 +312,7 @@ class PromptServer():
if
"client_id"
in
json_data
:
extra_data
[
"client_id"
]
=
json_data
[
"client_id"
]
if
valid
[
0
]:
self
.
prompt_queue
.
put
((
number
,
id
(
prompt
),
prompt
,
extra_data
))
self
.
prompt_queue
.
put
((
number
,
id
(
prompt
),
prompt
,
extra_data
,
valid
[
2
]
))
else
:
resp_code
=
400
out_string
=
valid
[
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment