Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
49d2e5bb
Commit
49d2e5bb
authored
Feb 27, 2023
by
comfyanonymous
Browse files
Move some stuff from main.py to execution.py
parent
f30a2deb
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
358 additions
and
352 deletions
+358
-352
execution.py
execution.py
+352
-0
main.py
main.py
+4
-350
server.py
server.py
+2
-2
No files found.
execution.py
0 → 100644
View file @
49d2e5bb
import
os
import
sys
import
copy
import
json
import
threading
import
heapq
import
traceback
import
torch
import
nodes
def get_input_data(inputs, class_def, outputs=None, prompt=None, extra_data=None):
    """Resolve a node's raw ``inputs`` dict into the kwargs for its FUNCTION.

    inputs: mapping of input name -> either a literal value or a 2-element
        list ``[source_node_id, output_index]`` linking to another node's
        already-computed output.
    class_def: node class; its ``INPUT_TYPES()`` declares which names are
        "required"/"optional" (literals passed through) and "hidden"
        (filled from ``prompt`` / ``extra_data``).
    outputs: mapping of node id -> output tuple, used to resolve links.
    prompt: full prompt dict, injected for the hidden input kind "PROMPT".
    extra_data: per-run data; its 'extra_pnginfo' entry is injected for the
        hidden input kind "EXTRA_PNGINFO".

    Returns a dict of input name -> resolved value.
    """
    # Fix: the original used mutable default arguments (``={}``), which share
    # a single dict across all calls.  ``None`` sentinels are backward
    # compatible for every caller.
    if outputs is None:
        outputs = {}
    if prompt is None:
        prompt = {}
    if extra_data is None:
        extra_data = {}

    valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            # Link: [source node id, output index] -> pull the cached value.
            input_unique_id = input_data[0]
            output_index = input_data[1]
            obj = outputs[input_unique_id][output_index]
            input_data_all[x] = obj
        else:
            # Literal: only forward it when the class actually declares it.
            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
                input_data_all[x] = input_data

    if "hidden" in valid_inputs:
        h = valid_inputs["hidden"]
        for x in h:
            if h[x] == "PROMPT":
                input_data_all[x] = prompt
            if h[x] == "EXTRA_PNGINFO":
                if "extra_pnginfo" in extra_data:
                    input_data_all[x] = extra_data['extra_pnginfo']
    return input_data_all
def recursive_execute(server, prompt, outputs, current_item, extra_data=None):
    """Execute node ``current_item``, first executing any unexecuted inputs.

    Results are cached in ``outputs`` (node id -> FUNCTION return value); a
    node already present in ``outputs`` is not re-run.  Progress
    ("executing") and UI results ("executed") are pushed to the client via
    ``server.send_sync`` when a client is attached.

    Returns the list of node ids executed by this call (dependencies first,
    ``current_item`` last); empty if the node was already cached.
    """
    # Fix: the original used a shared mutable default ``extra_data={}``.
    if extra_data is None:
        extra_data = {}

    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if unique_id in outputs:
        return []

    executed = []
    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            # Linked input [source node id, output index]: make sure the
            # source node has been executed first.
            input_unique_id = input_data[0]
            if input_unique_id not in outputs:
                executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)

    input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data)
    if server.client_id is not None:
        server.send_sync("executing", {"node": unique_id}, server.client_id)
    obj = class_def()
    outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
    if "ui" in outputs[unique_id] and server.client_id is not None:
        # Nodes may return a "ui" payload meant for the frontend.
        server.send_sync("executed", {"node": unique_id, "output": outputs[unique_id]["ui"]}, server.client_id)
    return executed + [unique_id]
def recursive_will_execute(prompt, outputs, current_item):
    """Return the ids of all nodes that would run to produce ``current_item``.

    A node already cached in ``outputs`` contributes nothing; otherwise the
    list contains every uncached (transitive) dependency followed by
    ``current_item`` itself.  Used only for its length, to pick the cheapest
    output to execute next.
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    will_execute = []
    if unique_id in outputs:
        return []

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            # Linked input [source node id, output index]; only the source
            # node id matters here (the original also unpacked the unused
            # output index).
            input_unique_id = input_data[0]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute(prompt, outputs, input_unique_id)

    return will_execute + [unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
    """Invalidate the cached output of ``current_item`` if it is stale.

    A node's cache entry is dropped (popped from ``outputs``) when its
    IS_CHANGED value differs from the previous run, when it was not present
    in ``old_prompt``, when its inputs changed, or when any upstream node it
    links to was itself invalidated.

    Note: as a side effect this memoizes the IS_CHANGED result into
    ``prompt[current_item]['is_changed']``.

    Returns True when the cached output was (or already had to be) deleted.
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

    # Compare the node's IS_CHANGED token against the one from the last run;
    # both default to '' so nodes without IS_CHANGED always compare equal.
    is_changed_old = ''
    is_changed = ''
    if hasattr(class_def, 'IS_CHANGED'):
        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
            is_changed_old = old_prompt[unique_id]['is_changed']
        if 'is_changed' not in prompt[unique_id]:
            # Evaluate IS_CHANGED once and cache it on the prompt entry.
            input_data_all = get_input_data(inputs, class_def)
            is_changed = class_def.IS_CHANGED(**input_data_all)
            prompt[unique_id]['is_changed'] = is_changed
        else:
            is_changed = prompt[unique_id]['is_changed']

    # Nothing cached for this node: it will have to execute.
    if unique_id not in outputs:
        return True

    to_delete = False
    if is_changed != is_changed_old:
        to_delete = True
    elif unique_id not in old_prompt:
        to_delete = True
    elif inputs == old_prompt[unique_id]['inputs']:
        # Inputs are textually unchanged; still stale if any linked upstream
        # node's cache gets invalidated.
        for x in inputs:
            input_data = inputs[x]
            if isinstance(input_data, list):
                input_unique_id = input_data[0]
                output_index = input_data[1]
                if input_unique_id in outputs:
                    to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
                else:
                    to_delete = True
                if to_delete:
                    break
    else:
        to_delete = True

    if to_delete:
        d = outputs.pop(unique_id)
        del d
    return to_delete
class PromptExecutor:
    """Executes prompts (node graphs), caching node outputs between runs.

    self.outputs maps node id -> the value returned by that node's FUNCTION;
    self.old_prompt maps node id -> a deep copy of the prompt entry that
    produced the cached output, so unchanged nodes can be skipped next run.
    """

    def __init__(self, server):
        self.outputs = {}     # node id -> cached FUNCTION result
        self.old_prompt = {}  # node id -> deepcopy of last-executed prompt entry
        self.server = server

    def execute(self, prompt, extra_data=None):
        """Execute every valid output node of ``prompt``.

        prompt: mapping of node id -> {'inputs': ..., 'class_type': ...}.
        extra_data: optional per-run data (e.g. 'client_id',
            'extra_pnginfo') forwarded to the nodes.
        """
        # Fix: the original used a shared mutable default ``extra_data={}``.
        if extra_data is None:
            extra_data = {}

        if "client_id" in extra_data:
            self.server.client_id = extra_data["client_id"]
        else:
            self.server.client_id = None

        with torch.no_grad():
            # Drop cached outputs that are stale relative to the new prompt.
            for x in prompt:
                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

            current_outputs = set(self.outputs.keys())
            executed = []
            try:
                to_execute = []
                for x in prompt:
                    class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
                    if hasattr(class_, 'OUTPUT_NODE'):
                        to_execute += [(0, x)]

                while len(to_execute) > 0:
                    #always execute the output that depends on the least amount of unexecuted nodes first
                    to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
                    x = to_execute.pop(0)[-1]

                    class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
                    if hasattr(class_, 'OUTPUT_NODE'):
                        if class_.OUTPUT_NODE == True:
                            valid = False
                            try:
                                m = validate_inputs(prompt, x)
                                valid = m[0]
                            except Exception:
                                # Validation is best-effort: a node that fails
                                # to validate is simply not executed.
                                # (Was a bare ``except:`` — narrowed so
                                # KeyboardInterrupt/SystemExit propagate.)
                                valid = False
                            if valid:
                                executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
            except Exception:
                # (Was ``except Exception as e`` with ``e`` unused.)
                print(traceback.format_exc())
                # Roll back: drop everything produced during this failed run
                # (any output that was not already cached when we started).
                to_delete = []
                for o in self.outputs:
                    if o not in current_outputs:
                        to_delete += [o]
                        if o in self.old_prompt:
                            d = self.old_prompt.pop(o)
                            del d
                for o in to_delete:
                    d = self.outputs.pop(o)
                    del d
            else:
                # Remember which prompt entries produced the cached outputs.
                executed = set(executed)
                for x in executed:
                    self.old_prompt[x] = copy.deepcopy(prompt[x])
            finally:
                if self.server.client_id is not None:
                    # node=None tells the client that execution finished.
                    self.server.send_sync("executing", {"node": None}, self.server.client_id)
                torch.cuda.empty_cache()
def validate_inputs(prompt, item):
    """Validate (and normalize in place) the required inputs of node ``item``.

    Linked inputs are checked for well-formedness and for a matching return
    type on the source node, then the source node is validated recursively.
    Literal inputs are coerced in place to the declared primitive type
    (INT/FLOAT/STRING) and checked against any min/max bounds or choice
    list.  Note the coercions mutate ``prompt[item]['inputs']``.

    Returns (True, "") on success, else (False, reason).
    """
    unique_id = item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    class_inputs = obj_class.INPUT_TYPES()
    required_inputs = class_inputs['required']
    for x in required_inputs:
        if x not in inputs:
            return (False, "Required input is missing. {}, {}".format(class_type, x))
        val = inputs[x]
        info = required_inputs[x]
        # info is (type, [options]); type is a string like "INT" or a list of
        # allowed values (a combo/choice input).
        type_input = info[0]
        if isinstance(val, list):
            # Linked input: must be exactly [source node id, output index].
            if len(val) != 2:
                return (False, "Bad Input. {}, {}".format(class_type, x))
            o_id = val[0]
            o_class_type = prompt[o_id]['class_type']
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            if r[val[1]] != type_input:
                return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
            # Recurse into the source node; propagate its failure verbatim.
            r = validate_inputs(prompt, o_id)
            if r[0] == False:
                return r
        else:
            # Literal input: coerce to the declared primitive type in place.
            if type_input == "INT":
                val = int(val)
                inputs[x] = val
            if type_input == "FLOAT":
                val = float(val)
                inputs[x] = val
            if type_input == "STRING":
                val = str(val)
                inputs[x] = val

            # Optional second info element carries constraints such as
            # min/max bounds.
            if len(info) > 1:
                if "min" in info[1] and val < info[1]["min"]:
                    return (False, "Value smaller than min. {}, {}".format(class_type, x))
                if "max" in info[1] and val > info[1]["max"]:
                    return (False, "Value bigger than max. {}, {}".format(class_type, x))

            # Choice input: the literal must be one of the listed values.
            if isinstance(type_input, list):
                if val not in type_input:
                    return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
    return (True, "")
def validate_prompt(prompt):
    """Check that ``prompt`` has at least one valid, properly wired output node.

    Returns (True, "") when at least one output node validates; otherwise
    (False, reason) with the collected validation errors.
    """
    outputs = set()
    for x in prompt:
        class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
            outputs.add(x)

    if len(outputs) == 0:
        return (False, "Prompt has no outputs")

    good_outputs = set()
    errors = []
    for o in outputs:
        valid = False
        reason = ""
        try:
            m = validate_inputs(prompt, o)
            valid = m[0]
            reason = m[1]
        except Exception:
            # Any crash during validation marks this output invalid.
            # (Was a bare ``except:``.)
            valid = False
            reason = "Parsing error"

        if valid == True:
            # Bug fix: the original did ``good_outputs.add(x)``, adding the
            # stale loop variable from the first loop instead of the output
            # node actually being validated.
            good_outputs.add(o)
        else:
            print("Failed to validate prompt for output {} {}".format(o, reason))
            print("output will be ignored")
            errors += [(o, reason)]

    if len(good_outputs) == 0:
        errors_list = "\n".join(map(lambda a: "{}".format(a[1]), errors))
        return (False, "Prompt has no properly connected outputs\n{}".format(errors_list))

    return (True, "")
class PromptQueue:
    """Thread-safe priority queue of prompt items plus a history of finished runs.

    Items are arbitrary comparable tuples ordered by ``heapq``; ``item[1]``
    is treated as the prompt id when the item is recorded into history.
    Registers itself on the server as ``server.prompt_queue``.
    """

    def __init__(self, server):
        self.server = server
        self.mutex = threading.RLock()
        # Condition shares the mutex so waiters wake under the same lock.
        self.not_empty = threading.Condition(self.mutex)
        self.task_counter = 0        # monotonically increasing task id
        self.queue = []              # heap of pending items
        self.currently_running = {}  # task id -> deep copy of the running item
        self.history = {}            # prompt id -> {"prompt": ..., "outputs": ...}
        server.prompt_queue = self

    def put(self, item):
        """Enqueue ``item`` and wake one waiting consumer."""
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self):
        """Block until an item is available; return ``(item, task_id)``."""
        with self.not_empty:
            while not self.queue:
                self.not_empty.wait()
            task = heapq.heappop(self.queue)
            task_id = self.task_counter
            self.currently_running[task_id] = copy.deepcopy(task)
            self.task_counter += 1
            self.server.queue_updated()
            return (task, task_id)

    def task_done(self, item_id, outputs):
        """Mark task ``item_id`` finished and record its UI outputs in history."""
        with self.mutex:
            finished = self.currently_running.pop(item_id)
            entry = {"prompt": finished, "outputs": {}}
            self.history[finished[1]] = entry
            for node_id in outputs:
                # Only the frontend-facing "ui" payloads are kept.
                if "ui" in outputs[node_id]:
                    entry["outputs"][node_id] = outputs[node_id]["ui"]
            self.server.queue_updated()

    def get_current_queue(self):
        """Return ``(running_items, pending_heap_copy)`` as a snapshot."""
        with self.mutex:
            running = list(self.currently_running.values())
            return (running, copy.deepcopy(self.queue))

    def get_tasks_remaining(self):
        """Number of items still pending or currently running."""
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        """Discard all pending items."""
        with self.mutex:
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
        """Remove the first pending item for which ``function(item)`` is true.

        Returns True when something was removed, False otherwise.
        """
        with self.mutex:
            for idx in range(len(self.queue)):
                if function(self.queue[idx]):
                    if len(self.queue) == 1:
                        self.wipe_queue()
                    else:
                        self.queue.pop(idx)
                        # Restore the heap invariant after an arbitrary removal.
                        heapq.heapify(self.queue)
                        self.server.queue_updated()
                    return True
            return False

    def get_history(self):
        """Return a deep copy of the history mapping."""
        with self.mutex:
            return copy.deepcopy(self.history)

    def wipe_history(self):
        """Discard the whole history."""
        with self.mutex:
            self.history = {}

    def delete_history_item(self, id_to_delete):
        """Remove one history entry; silently ignore an unknown id."""
        with self.mutex:
            self.history.pop(id_to_delete, None)
main.py
View file @
49d2e5bb
import
os
import
sys
import
copy
import
json
import
threading
import
heapq
import
traceback
import
asyncio
if
os
.
name
==
"nt"
:
import
logging
logging
.
getLogger
(
"xformers"
).
addFilter
(
lambda
record
:
'A matching Triton is not available'
not
in
record
.
getMessage
())
import
execution
import
server
if
__name__
==
"__main__"
:
...
...
@@ -32,357 +30,13 @@ if __name__ == "__main__":
print
(
"disabling upcasting of attention"
)
os
.
environ
[
'ATTN_PRECISION'
]
=
"fp16"
import
torch
import
nodes
def
get_input_data
(
inputs
,
class_def
,
outputs
=
{},
prompt
=
{},
extra_data
=
{}):
valid_inputs
=
class_def
.
INPUT_TYPES
()
input_data_all
=
{}
for
x
in
inputs
:
input_data
=
inputs
[
x
]
if
isinstance
(
input_data
,
list
):
input_unique_id
=
input_data
[
0
]
output_index
=
input_data
[
1
]
obj
=
outputs
[
input_unique_id
][
output_index
]
input_data_all
[
x
]
=
obj
else
:
if
(
"required"
in
valid_inputs
and
x
in
valid_inputs
[
"required"
])
or
(
"optional"
in
valid_inputs
and
x
in
valid_inputs
[
"optional"
]):
input_data_all
[
x
]
=
input_data
if
"hidden"
in
valid_inputs
:
h
=
valid_inputs
[
"hidden"
]
for
x
in
h
:
if
h
[
x
]
==
"PROMPT"
:
input_data_all
[
x
]
=
prompt
if
h
[
x
]
==
"EXTRA_PNGINFO"
:
if
"extra_pnginfo"
in
extra_data
:
input_data_all
[
x
]
=
extra_data
[
'extra_pnginfo'
]
return
input_data_all
def
recursive_execute
(
server
,
prompt
,
outputs
,
current_item
,
extra_data
=
{}):
unique_id
=
current_item
inputs
=
prompt
[
unique_id
][
'inputs'
]
class_type
=
prompt
[
unique_id
][
'class_type'
]
class_def
=
nodes
.
NODE_CLASS_MAPPINGS
[
class_type
]
if
unique_id
in
outputs
:
return
[]
executed
=
[]
for
x
in
inputs
:
input_data
=
inputs
[
x
]
if
isinstance
(
input_data
,
list
):
input_unique_id
=
input_data
[
0
]
output_index
=
input_data
[
1
]
if
input_unique_id
not
in
outputs
:
executed
+=
recursive_execute
(
server
,
prompt
,
outputs
,
input_unique_id
,
extra_data
)
input_data_all
=
get_input_data
(
inputs
,
class_def
,
outputs
,
prompt
,
extra_data
)
if
server
.
client_id
is
not
None
:
server
.
send_sync
(
"executing"
,
{
"node"
:
unique_id
},
server
.
client_id
)
obj
=
class_def
()
outputs
[
unique_id
]
=
getattr
(
obj
,
obj
.
FUNCTION
)(
**
input_data_all
)
if
"ui"
in
outputs
[
unique_id
]
and
server
.
client_id
is
not
None
:
server
.
send_sync
(
"executed"
,
{
"node"
:
unique_id
,
"output"
:
outputs
[
unique_id
][
"ui"
]
},
server
.
client_id
)
return
executed
+
[
unique_id
]
def
recursive_will_execute
(
prompt
,
outputs
,
current_item
):
unique_id
=
current_item
inputs
=
prompt
[
unique_id
][
'inputs'
]
will_execute
=
[]
if
unique_id
in
outputs
:
return
[]
for
x
in
inputs
:
input_data
=
inputs
[
x
]
if
isinstance
(
input_data
,
list
):
input_unique_id
=
input_data
[
0
]
output_index
=
input_data
[
1
]
if
input_unique_id
not
in
outputs
:
will_execute
+=
recursive_will_execute
(
prompt
,
outputs
,
input_unique_id
)
return
will_execute
+
[
unique_id
]
def
recursive_output_delete_if_changed
(
prompt
,
old_prompt
,
outputs
,
current_item
):
unique_id
=
current_item
inputs
=
prompt
[
unique_id
][
'inputs'
]
class_type
=
prompt
[
unique_id
][
'class_type'
]
class_def
=
nodes
.
NODE_CLASS_MAPPINGS
[
class_type
]
is_changed_old
=
''
is_changed
=
''
if
hasattr
(
class_def
,
'IS_CHANGED'
):
if
unique_id
in
old_prompt
and
'is_changed'
in
old_prompt
[
unique_id
]:
is_changed_old
=
old_prompt
[
unique_id
][
'is_changed'
]
if
'is_changed'
not
in
prompt
[
unique_id
]:
input_data_all
=
get_input_data
(
inputs
,
class_def
)
is_changed
=
class_def
.
IS_CHANGED
(
**
input_data_all
)
prompt
[
unique_id
][
'is_changed'
]
=
is_changed
else
:
is_changed
=
prompt
[
unique_id
][
'is_changed'
]
if
unique_id
not
in
outputs
:
return
True
to_delete
=
False
if
is_changed
!=
is_changed_old
:
to_delete
=
True
elif
unique_id
not
in
old_prompt
:
to_delete
=
True
elif
inputs
==
old_prompt
[
unique_id
][
'inputs'
]:
for
x
in
inputs
:
input_data
=
inputs
[
x
]
if
isinstance
(
input_data
,
list
):
input_unique_id
=
input_data
[
0
]
output_index
=
input_data
[
1
]
if
input_unique_id
in
outputs
:
to_delete
=
recursive_output_delete_if_changed
(
prompt
,
old_prompt
,
outputs
,
input_unique_id
)
else
:
to_delete
=
True
if
to_delete
:
break
else
:
to_delete
=
True
if
to_delete
:
d
=
outputs
.
pop
(
unique_id
)
del
d
return
to_delete
class
PromptExecutor
:
def
__init__
(
self
,
server
):
self
.
outputs
=
{}
self
.
old_prompt
=
{}
self
.
server
=
server
def
execute
(
self
,
prompt
,
extra_data
=
{}):
if
"client_id"
in
extra_data
:
self
.
server
.
client_id
=
extra_data
[
"client_id"
]
else
:
self
.
server
.
client_id
=
None
with
torch
.
no_grad
():
for
x
in
prompt
:
recursive_output_delete_if_changed
(
prompt
,
self
.
old_prompt
,
self
.
outputs
,
x
)
current_outputs
=
set
(
self
.
outputs
.
keys
())
executed
=
[]
try
:
to_execute
=
[]
for
x
in
prompt
:
class_
=
nodes
.
NODE_CLASS_MAPPINGS
[
prompt
[
x
][
'class_type'
]]
if
hasattr
(
class_
,
'OUTPUT_NODE'
):
to_execute
+=
[(
0
,
x
)]
while
len
(
to_execute
)
>
0
:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute
=
sorted
(
list
(
map
(
lambda
a
:
(
len
(
recursive_will_execute
(
prompt
,
self
.
outputs
,
a
[
-
1
])),
a
[
-
1
]),
to_execute
)))
x
=
to_execute
.
pop
(
0
)[
-
1
]
class_
=
nodes
.
NODE_CLASS_MAPPINGS
[
prompt
[
x
][
'class_type'
]]
if
hasattr
(
class_
,
'OUTPUT_NODE'
):
if
class_
.
OUTPUT_NODE
==
True
:
valid
=
False
try
:
m
=
validate_inputs
(
prompt
,
x
)
valid
=
m
[
0
]
except
:
valid
=
False
if
valid
:
executed
+=
recursive_execute
(
self
.
server
,
prompt
,
self
.
outputs
,
x
,
extra_data
)
except
Exception
as
e
:
print
(
traceback
.
format_exc
())
to_delete
=
[]
for
o
in
self
.
outputs
:
if
o
not
in
current_outputs
:
to_delete
+=
[
o
]
if
o
in
self
.
old_prompt
:
d
=
self
.
old_prompt
.
pop
(
o
)
del
d
for
o
in
to_delete
:
d
=
self
.
outputs
.
pop
(
o
)
del
d
else
:
executed
=
set
(
executed
)
for
x
in
executed
:
self
.
old_prompt
[
x
]
=
copy
.
deepcopy
(
prompt
[
x
])
finally
:
if
self
.
server
.
client_id
is
not
None
:
self
.
server
.
send_sync
(
"executing"
,
{
"node"
:
None
},
self
.
server
.
client_id
)
torch
.
cuda
.
empty_cache
()
def
validate_inputs
(
prompt
,
item
):
unique_id
=
item
inputs
=
prompt
[
unique_id
][
'inputs'
]
class_type
=
prompt
[
unique_id
][
'class_type'
]
obj_class
=
nodes
.
NODE_CLASS_MAPPINGS
[
class_type
]
class_inputs
=
obj_class
.
INPUT_TYPES
()
required_inputs
=
class_inputs
[
'required'
]
for
x
in
required_inputs
:
if
x
not
in
inputs
:
return
(
False
,
"Required input is missing. {}, {}"
.
format
(
class_type
,
x
))
val
=
inputs
[
x
]
info
=
required_inputs
[
x
]
type_input
=
info
[
0
]
if
isinstance
(
val
,
list
):
if
len
(
val
)
!=
2
:
return
(
False
,
"Bad Input. {}, {}"
.
format
(
class_type
,
x
))
o_id
=
val
[
0
]
o_class_type
=
prompt
[
o_id
][
'class_type'
]
r
=
nodes
.
NODE_CLASS_MAPPINGS
[
o_class_type
].
RETURN_TYPES
if
r
[
val
[
1
]]
!=
type_input
:
return
(
False
,
"Return type mismatch. {}, {}"
.
format
(
class_type
,
x
))
r
=
validate_inputs
(
prompt
,
o_id
)
if
r
[
0
]
==
False
:
return
r
else
:
if
type_input
==
"INT"
:
val
=
int
(
val
)
inputs
[
x
]
=
val
if
type_input
==
"FLOAT"
:
val
=
float
(
val
)
inputs
[
x
]
=
val
if
type_input
==
"STRING"
:
val
=
str
(
val
)
inputs
[
x
]
=
val
if
len
(
info
)
>
1
:
if
"min"
in
info
[
1
]
and
val
<
info
[
1
][
"min"
]:
return
(
False
,
"Value smaller than min. {}, {}"
.
format
(
class_type
,
x
))
if
"max"
in
info
[
1
]
and
val
>
info
[
1
][
"max"
]:
return
(
False
,
"Value bigger than max. {}, {}"
.
format
(
class_type
,
x
))
if
isinstance
(
type_input
,
list
):
if
val
not
in
type_input
:
return
(
False
,
"Value not in list. {}, {}: {} not in {}"
.
format
(
class_type
,
x
,
val
,
type_input
))
return
(
True
,
""
)
def
validate_prompt
(
prompt
):
outputs
=
set
()
for
x
in
prompt
:
class_
=
nodes
.
NODE_CLASS_MAPPINGS
[
prompt
[
x
][
'class_type'
]]
if
hasattr
(
class_
,
'OUTPUT_NODE'
)
and
class_
.
OUTPUT_NODE
==
True
:
outputs
.
add
(
x
)
if
len
(
outputs
)
==
0
:
return
(
False
,
"Prompt has no outputs"
)
good_outputs
=
set
()
errors
=
[]
for
o
in
outputs
:
valid
=
False
reason
=
""
try
:
m
=
validate_inputs
(
prompt
,
o
)
valid
=
m
[
0
]
reason
=
m
[
1
]
except
:
valid
=
False
reason
=
"Parsing error"
if
valid
==
True
:
good_outputs
.
add
(
x
)
else
:
print
(
"Failed to validate prompt for output {} {}"
.
format
(
o
,
reason
))
print
(
"output will be ignored"
)
errors
+=
[(
o
,
reason
)]
if
len
(
good_outputs
)
==
0
:
errors_list
=
"
\n
"
.
join
(
map
(
lambda
a
:
"{}"
.
format
(
a
[
1
]),
errors
))
return
(
False
,
"Prompt has no properly connected outputs
\n
{}"
.
format
(
errors_list
))
return
(
True
,
""
)
def prompt_worker(q, server):
    """Worker loop: pull queued prompts off ``q`` and execute them forever.

    Each queue item's last two elements are (prompt, extra_data); the
    executor's accumulated outputs are reported back via ``q.task_done``.
    """
    # Diff artifact fix: the page shows both the removed
    # ``e = PromptExecutor(server)`` and the added
    # ``e = execution.PromptExecutor(server)`` — only the latter exists
    # after this commit.
    e = execution.PromptExecutor(server)
    while True:
        item, item_id = q.get()
        e.execute(item[-2], item[-1])
        q.task_done(item_id, e.outputs)
class
PromptQueue
:
def
__init__
(
self
,
server
):
self
.
server
=
server
self
.
mutex
=
threading
.
RLock
()
self
.
not_empty
=
threading
.
Condition
(
self
.
mutex
)
self
.
task_counter
=
0
self
.
queue
=
[]
self
.
currently_running
=
{}
self
.
history
=
{}
server
.
prompt_queue
=
self
def
put
(
self
,
item
):
with
self
.
mutex
:
heapq
.
heappush
(
self
.
queue
,
item
)
self
.
server
.
queue_updated
()
self
.
not_empty
.
notify
()
def
get
(
self
):
with
self
.
not_empty
:
while
len
(
self
.
queue
)
==
0
:
self
.
not_empty
.
wait
()
item
=
heapq
.
heappop
(
self
.
queue
)
i
=
self
.
task_counter
self
.
currently_running
[
i
]
=
copy
.
deepcopy
(
item
)
self
.
task_counter
+=
1
self
.
server
.
queue_updated
()
return
(
item
,
i
)
def
task_done
(
self
,
item_id
,
outputs
):
with
self
.
mutex
:
prompt
=
self
.
currently_running
.
pop
(
item_id
)
self
.
history
[
prompt
[
1
]]
=
{
"prompt"
:
prompt
,
"outputs"
:
{}
}
for
o
in
outputs
:
if
"ui"
in
outputs
[
o
]:
self
.
history
[
prompt
[
1
]][
"outputs"
][
o
]
=
outputs
[
o
][
"ui"
]
self
.
server
.
queue_updated
()
def
get_current_queue
(
self
):
with
self
.
mutex
:
out
=
[]
for
x
in
self
.
currently_running
.
values
():
out
+=
[
x
]
return
(
out
,
copy
.
deepcopy
(
self
.
queue
))
def
get_tasks_remaining
(
self
):
with
self
.
mutex
:
return
len
(
self
.
queue
)
+
len
(
self
.
currently_running
)
def
wipe_queue
(
self
):
with
self
.
mutex
:
self
.
queue
=
[]
self
.
server
.
queue_updated
()
def
delete_queue_item
(
self
,
function
):
with
self
.
mutex
:
for
x
in
range
(
len
(
self
.
queue
)):
if
function
(
self
.
queue
[
x
]):
if
len
(
self
.
queue
)
==
1
:
self
.
wipe_queue
()
else
:
self
.
queue
.
pop
(
x
)
heapq
.
heapify
(
self
.
queue
)
self
.
server
.
queue_updated
()
return
True
return
False
def
get_history
(
self
):
with
self
.
mutex
:
return
copy
.
deepcopy
(
self
.
history
)
def
wipe_history
(
self
):
with
self
.
mutex
:
self
.
history
=
{}
def
delete_history_item
(
self
,
id_to_delete
):
with
self
.
mutex
:
self
.
history
.
pop
(
id_to_delete
,
None
)
async def run(server, address='', port=8188, verbose=True):
    """Run the web server and its publish loop concurrently.

    Awaits ``server.start(address, port, verbose)`` together with
    ``server.publish_loop()`` and returns when both complete.
    """
    serving = server.start(address, port, verbose)
    publishing = server.publish_loop()
    await asyncio.gather(serving, publishing)
...
...
@@ -400,7 +54,7 @@ if __name__ == "__main__":
loop
=
asyncio
.
new_event_loop
()
asyncio
.
set_event_loop
(
loop
)
server
=
server
.
PromptServer
(
loop
)
q
=
PromptQueue
(
server
)
q
=
execution
.
PromptQueue
(
server
)
hijack_progress
(
server
)
...
...
server.py
View file @
49d2e5bb
...
...
@@ -2,7 +2,7 @@ import os
import
sys
import
asyncio
import
nodes
import
main
import
execution
import
uuid
import
json
...
...
@@ -111,7 +111,7 @@ class PromptServer():
if
"prompt"
in
json_data
:
prompt
=
json_data
[
"prompt"
]
valid
=
main
.
validate_prompt
(
prompt
)
valid
=
execution
.
validate_prompt
(
prompt
)
extra_data
=
{}
if
"extra_data"
in
json_data
:
extra_data
=
json_data
[
"extra_data"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment