chenpangpang / ComfyUI · Commits

Commit 69cc75fb, authored Mar 02, 2023 by comfyanonymous
Parent: 1e2c4df9

    Add a way to interrupt current processing in the backend.

Showing 5 changed files with 44 additions and 1 deletion (+44 −1)
comfy/model_management.py   +28 −0
comfy/samplers.py            +2 −0
execution.py                 +1 −0
nodes.py                     +7 −0
server.py                    +6 −1
comfy/model_management.py

```diff
@@ -162,3 +162,31 @@ def maximum_batch_area():
     memory_free = get_free_memory() / (1024 * 1024)
     area = ((memory_free - 1024) * 0.9) / (0.6)
     return int(max(area, 0))
+#TODO: might be cleaner to put this somewhere else
+import threading
+
+class InterruptProcessingException(Exception):
+    pass
+
+interrupt_processing_mutex = threading.RLock()
+
+interrupt_processing = False
+def interrupt_current_processing(value=True):
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        interrupt_processing = value
+
+def processing_interrupted():
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        return interrupt_processing
+
+def throw_exception_if_processing_interrupted():
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        if interrupt_processing:
+            interrupt_processing = False
+            raise InterruptProcessingException()
```
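These helpers give other threads a cooperative way to stop a long computation: a caller flips the module-level flag under the RLock, and the worker polls throw_exception_if_processing_interrupted() at safe points, which resets the flag and unwinds via the exception. A minimal sketch of the pattern using the new API (assuming the module is importable as comfy.model_management and no other consumer of the flag is running):

```python
import threading
import time

import comfy.model_management as model_management

def worker():
    try:
        for step in range(100):  # stand-in for a long sampling loop
            time.sleep(0.05)     # simulated work for one step
            # Safe checkpoint: raises InterruptProcessingException
            # (and resets the flag) if an interrupt was requested.
            model_management.throw_exception_if_processing_interrupted()
    except model_management.InterruptProcessingException:
        print("processing interrupted at step", step)

t = threading.Thread(target=worker)
t.start()
time.sleep(0.2)
model_management.interrupt_current_processing()  # request a stop from another thread
t.join()
```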
comfy/samplers.py

```diff
@@ -172,6 +172,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
             del input_x
 
+            model_management.throw_exception_if_processing_interrupted()
+
             for o in range(batch_chunks):
                 if cond_or_uncond[o] == COND:
                     out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
```
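The check lands right after each batched model evaluation inside sampling_function, so an interrupt takes effect once per model call rather than mid-evaluation; the just-computed output for that batch is discarded as the exception unwinds, before it is accumulated into out_cond.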
execution.py

```diff
@@ -58,6 +58,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
         server.send_sync("executing", { "node": unique_id }, server.client_id)
 
     obj = class_def()
+    nodes.before_node_execution()
     outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
     if "ui" in outputs[unique_id] and server.client_id is not None:
         server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
```
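Calling nodes.before_node_execution() here means the interrupt is also checked once per node, so a prompt is stoppable even between nodes that never reach the sampler. The diff doesn't show where InterruptProcessingException is ultimately caught; presumably whatever drives recursive_execute handles it. A hypothetical sketch of such a caller (the driver function and its arguments are assumptions, not part of this commit):

```python
import execution
import comfy.model_management as model_management

def run_prompt(server, prompt, extra_data={}):
    # Hypothetical driver: execute every node, aborting the whole prompt
    # if any checkpoint raised the interrupt exception.
    outputs = {}
    try:
        for node_id in prompt:
            execution.recursive_execute(server, prompt, outputs, node_id, extra_data)
    except model_management.InterruptProcessingException:
        # The flag was already reset inside
        # throw_exception_if_processing_interrupted(); just stop cleanly.
        print("prompt execution interrupted")
    return outputs
```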
nodes.py

```diff
@@ -41,6 +41,13 @@ def recursive_search(directory):
 def filter_files_extensions(files, extensions):
     return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
 
+def before_node_execution():
+    model_management.throw_exception_if_processing_interrupted()
+
+def interrupt_processing():
+    model_management.interrupt_current_processing()
+
 class CLIPTextEncode:
     @classmethod
     def INPUT_TYPES(s):
```
server.py

```diff
@@ -140,7 +140,12 @@ class PromptServer():
             self.prompt_queue.delete_queue_item(delete_func)
 
             return web.Response(status=200)
 
+        @routes.post("/interrupt")
+        async def post_interrupt(request):
+            nodes.interrupt_processing()
+            return web.Response(status=200)
+
         @routes.post("/history")
         async def post_history(request):
             json_data = await request.json()
```
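The new POST /interrupt route lets any HTTP client flag the currently running prompt for interruption; the handler only sets the shared flag, and the running job stops at its next checkpoint. A quick sketch using Python's requests (the host and port are assumptions; adjust to wherever your server listens):

```python
import requests

# Ask the backend to stop whatever prompt it is currently executing.
# The request body is empty; the handler only flips the interrupt flag.
resp = requests.post("http://127.0.0.1:8188/interrupt")
print(resp.status_code)  # 200 once the interrupt has been requested
```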