"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "9cdb5bf9fe656fb26d1e0a2fc7551af1e08cbfb2"
Commit 4b083142 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add more features to the backend queue code.

The queue can now be queried, entries can be deleted and prompts easily
queued to the front of the queue.

Just need to expose it in the UI next.
parent 9d611a90
...@@ -3,7 +3,7 @@ import sys ...@@ -3,7 +3,7 @@ import sys
import copy import copy
import json import json
import threading import threading
import queue import heapq
import traceback import traceback
if '--dont-upcast-attention' in sys.argv: if '--dont-upcast-attention' in sys.argv:
...@@ -148,6 +148,7 @@ class PromptExecutor: ...@@ -148,6 +148,7 @@ class PromptExecutor:
to_execute += [(0, x)] to_execute += [(0, x)]
while len(to_execute) > 0: while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1] x = to_execute.pop(0)[-1]
...@@ -266,10 +267,63 @@ def validate_prompt(prompt): ...@@ -266,10 +267,63 @@ def validate_prompt(prompt):
def prompt_worker(q): def prompt_worker(q):
e = PromptExecutor() e = PromptExecutor()
while True: while True:
item = q.get() item, item_id = q.get()
e.execute(item[-2], item[-1]) e.execute(item[-2], item[-1])
q.task_done() q.task_done(item_id)
class PromptQueue:
def __init__(self):
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0
self.queue = []
self.currently_running = {}
def put(self, item):
with self.mutex:
heapq.heappush(self.queue, item)
self.not_empty.notify()
def get(self):
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait()
item = heapq.heappop(self.queue)
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1
return (item, i)
def task_done(self, item_id):
with self.mutex:
self.currently_running.pop(item_id)
def get_current_queue(self):
with self.mutex:
out = []
for x in self.currently_running.values():
out += [x]
return (out, copy.deepcopy(self.queue))
def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)
def wipe_queue(self):
with self.mutex:
self.queue = []
def delete_queue_item(self, function):
with self.mutex:
for x in range(len(self.queue)):
if function(self.queue[x]):
if len(self.queue) == 1:
self.wipe_queue()
else:
self.queue.pop(x)
heapq.heapify(self.queue)
return True
return False
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
...@@ -285,9 +339,16 @@ class PromptServer(BaseHTTPRequestHandler): ...@@ -285,9 +339,16 @@ class PromptServer(BaseHTTPRequestHandler):
self._set_headers(ct='application/json') self._set_headers(ct='application/json')
prompt_info = {} prompt_info = {}
exec_info = {} exec_info = {}
exec_info['queue_remaining'] = self.server.prompt_queue.unfinished_tasks exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info prompt_info['exec_info'] = exec_info
self.wfile.write(json.dumps(prompt_info).encode('utf-8')) self.wfile.write(json.dumps(prompt_info).encode('utf-8'))
elif self.path == "/queue":
self._set_headers(ct='application/json')
queue_info = {}
current_queue = self.server.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
self.wfile.write(json.dumps(queue_info).encode('utf-8'))
elif self.path == "/object_info": elif self.path == "/object_info":
self._set_headers(ct='application/json') self._set_headers(ct='application/json')
out = {} out = {}
...@@ -325,12 +386,16 @@ class PromptServer(BaseHTTPRequestHandler): ...@@ -325,12 +386,16 @@ class PromptServer(BaseHTTPRequestHandler):
out_string = "" out_string = ""
if self.path == "/prompt": if self.path == "/prompt":
print("got prompt") print("got prompt")
self.data_string = self.rfile.read(int(self.headers['Content-Length'])) data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(self.data_string) json_data = json.loads(data_string)
if "number" in json_data: if "number" in json_data:
number = float(json_data['number']) number = float(json_data['number'])
else: else:
number = self.server.number number = self.server.number
if "front" in json_data:
if json_data['front']:
number = -number
self.server.number += 1 self.server.number += 1
if "prompt" in json_data: if "prompt" in json_data:
prompt = json_data["prompt"] prompt = json_data["prompt"]
...@@ -344,6 +409,18 @@ class PromptServer(BaseHTTPRequestHandler): ...@@ -344,6 +409,18 @@ class PromptServer(BaseHTTPRequestHandler):
resp_code = 400 resp_code = 400
out_string = valid[1] out_string = valid[1]
print("invalid prompt:", valid[1]) print("invalid prompt:", valid[1])
elif self.path == "/queue":
data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(data_string)
if "clear" in json_data:
if json_data["clear"]:
self.server.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.server.prompt_queue.delete_queue_item(delete_func)
self._set_headers(code=resp_code) self._set_headers(code=resp_code)
self.end_headers() self.end_headers()
self.wfile.write(out_string.encode('utf8')) self.wfile.write(out_string.encode('utf8'))
...@@ -366,7 +443,7 @@ def run(prompt_queue, address='', port=8188): ...@@ -366,7 +443,7 @@ def run(prompt_queue, address='', port=8188):
if __name__ == "__main__": if __name__ == "__main__":
q = queue.PriorityQueue() q = PromptQueue()
threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start() threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start()
run(q, address='127.0.0.1', port=8188) run(q, address='127.0.0.1', port=8188)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment