Unverified Commit a52aa9f4 authored by pythongosssss's avatar pythongosssss Committed by GitHub
Browse files

Moved api out to server

Reworked sockets to use socketio
Added progress to nodes
Added highlight to active node
Added preview to saveimage node
parent 92808718
...@@ -11,15 +11,7 @@ if os.name == "nt": ...@@ -11,15 +11,7 @@ if os.name == "nt":
import logging import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
try: import server
import aiohttp
from aiohttp import web
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()
if __name__ == "__main__": if __name__ == "__main__":
if '--help' in sys.argv: if '--help' in sys.argv:
...@@ -36,14 +28,14 @@ if __name__ == "__main__": ...@@ -36,14 +28,14 @@ if __name__ == "__main__":
print() print()
exit() exit()
if '--dont-upcast-attention' in sys.argv: if '--dont-upcast-attention' in sys.argv:
print("disabling upcasting of attention") print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16" os.environ['ATTN_PRECISION'] = "fp16"
import torch import torch
import nodes import nodes
def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}): def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}, server=None, unique_id=None):
valid_inputs = class_def.INPUT_TYPES() valid_inputs = class_def.INPUT_TYPES()
input_data_all = {} input_data_all = {}
for x in inputs: for x in inputs:
...@@ -65,9 +57,13 @@ def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}): ...@@ -65,9 +57,13 @@ def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
if h[x] == "EXTRA_PNGINFO": if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data: if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo'] input_data_all[x] = extra_data['extra_pnginfo']
if h[x] == "SERVER":
input_data_all[x] = server
if h[x] == "UNIQUE_ID":
input_data_all[x] = unique_id
return input_data_all return input_data_all
def recursive_execute(prompt, outputs, current_item, extra_data={}): def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
...@@ -84,9 +80,11 @@ def recursive_execute(prompt, outputs, current_item, extra_data={}): ...@@ -84,9 +80,11 @@ def recursive_execute(prompt, outputs, current_item, extra_data={}):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if input_unique_id not in outputs:
executed += recursive_execute(prompt, outputs, input_unique_id, extra_data) executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)
input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data) input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data, server, unique_id)
if server.client_id is not None:
server.send_sync("execute", { "node": unique_id }, server.client_id)
obj = class_def() obj = class_def()
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
...@@ -157,11 +155,17 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item ...@@ -157,11 +155,17 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
return to_delete return to_delete
class PromptExecutor: class PromptExecutor:
def __init__(self): def __init__(self, server):
self.outputs = {} self.outputs = {}
self.old_prompt = {} self.old_prompt = {}
self.server = server
def execute(self, prompt, extra_data={}): def execute(self, prompt, extra_data={}):
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
else:
self.server.client_id = None
with torch.no_grad(): with torch.no_grad():
for x in prompt: for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
...@@ -190,7 +194,7 @@ class PromptExecutor: ...@@ -190,7 +194,7 @@ class PromptExecutor:
except: except:
valid = False valid = False
if valid: if valid:
executed += recursive_execute(prompt, self.outputs, x, extra_data) executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
...@@ -208,6 +212,11 @@ class PromptExecutor: ...@@ -208,6 +212,11 @@ class PromptExecutor:
executed = set(executed) executed = set(executed)
for x in executed: for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x]) self.old_prompt[x] = copy.deepcopy(prompt[x])
finally:
if self.server.client_id is not None:
self.server.send_sync("execute", { "node": None }, self.server.client_id)
torch.cuda.empty_cache() torch.cuda.empty_cache()
def validate_inputs(prompt, item): def validate_inputs(prompt, item):
...@@ -293,27 +302,27 @@ def validate_prompt(prompt): ...@@ -293,27 +302,27 @@ def validate_prompt(prompt):
return (True, "") return (True, "")
def prompt_worker(q): def prompt_worker(q, server):
e = PromptExecutor() e = PromptExecutor(server)
while True: while True:
item, item_id = q.get() item, item_id = q.get()
e.execute(item[-2], item[-1]) e.execute(item[-2], item[-1])
q.task_done(item_id) q.task_done(item_id)
class PromptQueue: class PromptQueue:
def __init__(self, socket_handler): def __init__(self, server):
self.socket_handler = socket_handler self.server = server
self.mutex = threading.RLock() self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex) self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0 self.task_counter = 0
self.queue = [] self.queue = []
self.currently_running = {} self.currently_running = {}
socket_handler.prompt_queue = self server.prompt_queue = self
def put(self, item): def put(self, item):
with self.mutex: with self.mutex:
heapq.heappush(self.queue, item) heapq.heappush(self.queue, item)
self.socket_handler.queue_updated(self) self.server.queue_updated()
self.not_empty.notify() self.not_empty.notify()
def get(self): def get(self):
...@@ -324,13 +333,13 @@ class PromptQueue: ...@@ -324,13 +333,13 @@ class PromptQueue:
i = self.task_counter i = self.task_counter
self.currently_running[i] = copy.deepcopy(item) self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1 self.task_counter += 1
self.socket_handler.queue_updated(self) self.server.queue_updated()
return (item, i) return (item, i)
def task_done(self, item_id): def task_done(self, item_id):
with self.mutex: with self.mutex:
self.currently_running.pop(item_id) self.currently_running.pop(item_id)
self.socket_handler.queue_updated(self) self.server.queue_updated()
def get_current_queue(self): def get_current_queue(self):
with self.mutex: with self.mutex:
...@@ -346,7 +355,7 @@ class PromptQueue: ...@@ -346,7 +355,7 @@ class PromptQueue:
def wipe_queue(self): def wipe_queue(self):
with self.mutex: with self.mutex:
self.queue = [] self.queue = []
self.socket_handler.queue_updated(self) self.server.queue_updated()
def delete_queue_item(self, function): def delete_queue_item(self, function):
with self.mutex: with self.mutex:
...@@ -357,174 +366,32 @@ class PromptQueue: ...@@ -357,174 +366,32 @@ class PromptQueue:
else: else:
self.queue.pop(x) self.queue.pop(x)
heapq.heapify(self.queue) heapq.heapify(self.queue)
self.socket_handler.queue_updated(self) self.server.queue_updated()
return True return True
return False return False
def get_queue_info(prompt_queue): async def run(server, address='', port=8188):
prompt_info = {} await asyncio.gather(server.start(address, port), server.publish_loop())
exec_info = {}
exec_info['queue_remaining'] = prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info
class SocketHandler():
def __init__(self, loop):
self.connected = set()
self.messages = asyncio.Queue()
self.loop = loop
async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(msg)
def queue_updated(self, queue):
# This is called by the queue processing thread so we need to make it thread safe
loop.call_soon_threadsafe(self.messages.put_nowait, { 'type': 'status', 'status': get_queue_info(queue) })
async def send(self, message, socket = None):
if isinstance(message, str) == False:
message = json.dumps(message)
if socket is None:
for ws in self.connected:
await ws.send_str(message)
else:
await socket.send_str(message)
async def process(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
self.connected.add(ws)
try:
# Send initial state to the new client
await self.send({ 'type': 'status', 'status': get_queue_info(self.prompt_queue) }, ws)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
finally:
self.connected.remove(ws)
return ws
class PromptServer():
def __init__(self, prompt_queue, socket_handler):
self.prompt_queue = prompt_queue
self.socket_handler = socket_handler
self.number = 0
self.app = web.Application()
self.web_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "webshit")
routes = web.RouteTableDef()
@routes.get('/ws')
async def websocket_handler(request):
return await self.socket_handler.process(request)
@routes.get("/")
async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html"))
@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(get_queue_info(self.prompt_queue))
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['name'] = x #TODO
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
return web.json_response(out)
@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@routes.post("/prompt")
async def post_prompt(request):
print("got prompt")
resp_code = 200
out_string = ""
json_data = await request.json()
if "number" in json_data:
number = float(json_data['number'])
else:
number = self.number
if "front" in json_data:
if json_data['front']:
number = -number
self.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = validate_prompt(prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
return web.Response(body=out_string, status=resp_code)
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.prompt_queue.delete_queue_item(delete_func)
return web.Response(status=200)
self.app.add_routes(routes)
self.app.add_routes([
web.static('/', self.web_root),
])
async def start_server(server, address, port):
runner = web.AppRunner(server.app)
await runner.setup()
site = web.TCPSite(runner, address, port)
await site.start()
if address == '':
address = '0.0.0.0'
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
async def run(prompt_queue, socket_handler, address='', port=8188): def hijack_progress(server):
server = PromptServer(prompt_queue, socket_handler) from tqdm.auto import tqdm
await asyncio.gather(start_server(server, address, port), socket_handler.publish_loop()) orig_func = getattr(tqdm, "update")
def wrapped_func(*args, **kwargs):
pbar = args[0]
v = orig_func(*args, **kwargs)
server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
return v
setattr(tqdm, "update", wrapped_func)
if __name__ == "__main__": if __name__ == "__main__":
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
q = PromptQueue(server)
socket_handler = SocketHandler(loop) hijack_progress(server)
q = PromptQueue(socket_handler)
threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start() threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
if '--listen' in sys.argv: if '--listen' in sys.argv:
address = '0.0.0.0' address = '0.0.0.0'
else: else:
...@@ -537,5 +404,11 @@ if __name__ == "__main__": ...@@ -537,5 +404,11 @@ if __name__ == "__main__":
except: except:
pass pass
loop.run_until_complete(run(q, socket_handler, address=address, port=port)) if os.name == "nt":
try:
loop.run_until_complete(run(server, address=address, port=port))
except KeyboardInterrupt:
pass
else:
loop.run_until_complete(run(server, address=address, port=port))
...@@ -605,7 +605,7 @@ class SaveImage: ...@@ -605,7 +605,7 @@ class SaveImage:
return {"required": return {"required":
{"images": ("IMAGE", ), {"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"})}, "filename_prefix": ("STRING", {"default": "ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "server": "SERVER", "unique_id": "UNIQUE_ID"},
} }
RETURN_TYPES = () RETURN_TYPES = ()
...@@ -615,7 +615,7 @@ class SaveImage: ...@@ -615,7 +615,7 @@ class SaveImage:
CATEGORY = "image" CATEGORY = "image"
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, server=None, unique_id=None):
def map_filename(filename): def map_filename(filename):
prefix_len = len(filename_prefix) prefix_len = len(filename_prefix)
prefix = filename[:prefix_len + 1] prefix = filename[:prefix_len + 1]
...@@ -631,6 +631,8 @@ class SaveImage: ...@@ -631,6 +631,8 @@ class SaveImage:
except FileNotFoundError: except FileNotFoundError:
os.mkdir(self.output_dir) os.mkdir(self.output_dir)
counter = 1 counter = 1
paths = list()
for image in images: for image in images:
i = 255. * image.cpu().numpy() i = 255. * image.cpu().numpy()
img = Image.fromarray(i.astype(np.uint8)) img = Image.fromarray(i.astype(np.uint8))
...@@ -640,8 +642,12 @@ class SaveImage: ...@@ -640,8 +642,12 @@ class SaveImage:
if extra_pnginfo is not None: if extra_pnginfo is not None:
for x in extra_pnginfo: for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x])) metadata.add_text(x, json.dumps(extra_pnginfo[x]))
img.save(os.path.join(self.output_dir, f"{filename_prefix}_{counter:05}_.png"), pnginfo=metadata, optimize=True) file = f"{filename_prefix}_{counter:05}_.png"
img.save(os.path.join(self.output_dir, file), pnginfo=metadata, optimize=True)
paths.append(f"/view/{file}")
counter += 1 counter += 1
if server is not None:
server.send_sync("image", {"images": paths, "id": unique_id})
class LoadImage: class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
......
import os
import sys
import asyncio
import nodes
import main
try:
import aiohttp
from aiohttp import web
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()
try:
import socketio
except ImportError:
print("Module 'python-socketio' not installed. Please install it via:")
print("pip install python-socketio")
print("or")
print("pip install -r requirements.txt")
sys.exit()
class PromptServer():
def __init__(self, loop):
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
self.number = 0
self.app = web.Application()
self.sio = socketio.AsyncServer()
self.sio.attach(self.app)
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "webshit")
routes = web.RouteTableDef()
@self.sio.event
async def connect(sid, environ):
await self.sio.emit("status", self.get_queue_info(), sid)
@routes.get("/")
async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html"))
@routes.get("/view/{file}")
async def view_image(request):
if "file" in request.match_info:
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
file = request.match_info["file"]
file = os.path.splitext(os.path.basename(file))[0] + ".png"
file = os.path.join(output_dir, file)
if os.path.isfile(file):
return web.FileResponse(file)
return web.Response(status=404)
@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(self.get_queue_info())
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['name'] = x #TODO
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
return web.json_response(out)
@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@routes.post("/prompt")
async def post_prompt(request):
print("got prompt")
resp_code = 200
out_string = ""
json_data = await request.json()
if "number" in json_data:
number = float(json_data['number'])
else:
number = self.number
if "front" in json_data:
if json_data['front']:
number = -number
self.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = main.validate_prompt(prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
return web.Response(body=out_string, status=resp_code)
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.prompt_queue.delete_queue_item(delete_func)
return web.Response(status=200)
self.app.add_routes(routes)
self.app.add_routes([
web.static('/', self.web_root),
])
def get_queue_info(self):
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info
async def send(self, event, data, sid=None):
await self.sio.emit(event, data, to=sid)
def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))
def queue_updated(self):
self.send_sync("status", self.get_queue_info())
async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(*msg)
async def start(self, address, port):
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, address, port)
await site.start()
if address == '':
address = '0.0.0.0'
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
\ No newline at end of file
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
<head> <head>
<link rel="stylesheet" type="text/css" href="litegraph.css"> <link rel="stylesheet" type="text/css" href="litegraph.css">
<script type="text/javascript" src="litegraph.core.js"></script> <script type="text/javascript" src="litegraph.core.js"></script>
<script type="text/javascript" src="socket.io.min.js"></script>
</head> </head>
<style> <style>
.customtext_input { .customtext_input {
...@@ -26,7 +27,7 @@ ...@@ -26,7 +27,7 @@
left: 50%; /* Center the modal horizontally */ left: 50%; /* Center the modal horizontally */
top: 50%; /* Center the modal vertically */ top: 50%; /* Center the modal vertically */
transform: translate(-50%, -50%); /* Use this to center the modal */ transform: translate(-50%, -50%); /* Use this to center the modal */
width: 50%; /* Set a width for the modal */ min-width: 50%; /* Set a width for the modal */
height: auto; /* Set a height for the modal */ height: auto; /* Set a height for the modal */
padding: 30px; padding: 30px;
background-color: #ff0000; /* Modal background */ background-color: #ff0000; /* Modal background */
...@@ -58,6 +59,18 @@ ...@@ -58,6 +59,18 @@
white-space: pre-line; /* This will respect line breaks */ white-space: pre-line; /* This will respect line breaks */
margin-bottom: 20px; /* Add some margin between the text and the close button*/ margin-bottom: 20px; /* Add some margin between the text and the close button*/
} }
#modal-text img {
max-width: calc(100vw - 96px - 36px);
max-height: calc(100vh - 96px - 36px);
}
#images img {
width: 100%;
max-height: 300px;
object-fit: contain;
cursor: pointer;
}
</style> </style>
<div id="myErrorModal" class="modal"> <div id="myErrorModal" class="modal">
<div class="modal-content"> <div class="modal-content">
...@@ -65,7 +78,6 @@ ...@@ -65,7 +78,6 @@
<span class="close">CLOSE</span> <span class="close">CLOSE</span>
</div> </div>
</div> </div>
<canvas id='mycanvas' width='1000' height='1000' style='width: 100%; height: 100%;'></canvas> <canvas id='mycanvas' width='1000' height='1000' style='width: 100%; height: 100%;'></canvas>
<script> <script>
...@@ -75,6 +87,7 @@ var canvas = new LGraphCanvas("#mycanvas", graph); ...@@ -75,6 +87,7 @@ var canvas = new LGraphCanvas("#mycanvas", graph);
const ccc = document.getElementById("mycanvas"); const ccc = document.getElementById("mycanvas");
const ctx = ccc.getContext("2d"); const ctx = ccc.getContext("2d");
let images = {}
// Resize the canvas to match the size of the canvas element // Resize the canvas to match the size of the canvas element
function resizeCanvas() { function resizeCanvas() {
...@@ -132,6 +145,7 @@ function onObjectInfo(json) { ...@@ -132,6 +145,7 @@ function onObjectInfo(json) {
this._widgets = [] this._widgets = []
min_height = 1; min_height = 1;
min_width = 1; min_width = 1;
for (let x in inp) { for (let x in inp) {
let default_val = min_val = max_val = step_val = multiline = dynamic_prompt = undefined; let default_val = min_val = max_val = step_val = multiline = dynamic_prompt = undefined;
if (inp[x].length > 1) { if (inp[x].length > 1) {
...@@ -261,6 +275,48 @@ function onObjectInfo(json) { ...@@ -261,6 +275,48 @@ function onObjectInfo(json) {
} else { } else {
this.addInput(x, type); this.addInput(x, type);
} }
if(key === "SaveImage") {
MyNode.prototype.onDrawBackground = function(ctx) {
if(this.id + "" in images) {
const src = images[this.id + ""][0];
if(this.src !== src) {
this.img = null;
this.src = src;
const img = new Image();
img.src = src;
img.onload = () => {
graph.setDirtyCanvas(true);
this.img = img;
if(this.size[1] < 100) {
this.size[1] = 250;
}
}
}
if(this.img) {
let w = this.img.naturalWidth;
let h = this.img.naturalHeight;
let dw = this.size[0];
let dh = this.size[1];
const scaleX = dw / w;
const scaleY = dh / h;
const scale = Math.min(scaleX, scaleY, 1);
w *= scale;
h *= scale;
let x = (dw - w) / 2;
let y = (dh - h) / 2;
ctx.drawImage(this.img, x, y, w, h);
}
} else {
this.size[1] = 58
}
};
}
} }
out = j['output']; out = j['output'];
...@@ -388,19 +444,19 @@ function graphToPrompt() { ...@@ -388,19 +444,19 @@ function graphToPrompt() {
function closeModal() { function closeModal() {
var modal = document.getElementById("myErrorModal"); var modal = document.getElementById("myErrorModal");
modal.style.display = "none"; modal.setAttribute("style", "");
} }
function showModal(text) { function showModal(text) {
var modal = document.getElementById("myErrorModal"); var modal = document.getElementById("myErrorModal");
var modalText = document.getElementById("modal-text"); var modalText = document.getElementById("modal-text");
modalText.innerHTML = text; modalText.innerHTML = text;
modal.style.display = "block"; modal.setAttribute("style", "display: block");
var closeBtn = modal.getElementsByClassName("close")[0]; var closeBtn = modal.getElementsByClassName("close")[0];
closeBtn.onclick = function(event) {closeModal();} closeBtn.onclick = function(event) {closeModal();}
return modal
} }
function promptPosted(data) function promptPosted(data)
{ {
if (data.status != 200) { if (data.status != 200) {
...@@ -431,7 +487,7 @@ function promptPosted(data) ...@@ -431,7 +487,7 @@ function promptPosted(data)
function postPrompt(number) { function postPrompt(number) {
let prompt = graphToPrompt(); let prompt = graphToPrompt();
let full_data = {prompt: prompt, extra_data: {extra_pnginfo: {workflow: graph.serialize()}}}; let full_data = {client_id: clientId, prompt: prompt, extra_data: {extra_pnginfo: {workflow: graph.serialize()}}};
if (number == -1) { if (number == -1) {
full_data.front = true; full_data.front = true;
} else } else
...@@ -534,64 +590,105 @@ document.addEventListener('drop', (event) => { ...@@ -534,64 +590,105 @@ document.addEventListener('drop', (event) => {
prompt_file_load(file); prompt_file_load(file);
}); });
let runningNodeId = null;
let progress = null;
let clientId = null;
const orig = LGraphCanvas.prototype.drawNodeShape;
LGraphCanvas.prototype.drawNodeShape = function(node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
const res = orig.apply(this, arguments);
if(node.id + "" === runningNodeId) {
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
ctx.lineWidth = 1;
ctx.globalAlpha = 0.8;
ctx.beginPath();
if( shape == LiteGraph.BOX_SHAPE )
ctx.rect(-6,-6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0]+1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT );
else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed) )
ctx.roundRect(-6,-6 - LiteGraph.NODE_TITLE_HEIGHT, 12 +size[0]+1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT , this.round_radius * 2);
else if (shape == LiteGraph.CARD_SHAPE)
ctx.roundRect(-6,-6 + LiteGraph.NODE_TITLE_HEIGHT, 12 +size[0]+1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT , this.round_radius * 2, 2);
else if (shape == LiteGraph.CIRCLE_SHAPE)
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI*2);
ctx.strokeStyle = "#0f0"
ctx.stroke();
ctx.strokeStyle = fgcolor;
ctx.globalAlpha = 1;
if(progress) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (progress.value / progress.max), 6);
ctx.fillStyle = bgcolor;
}
}
return res;
}
function updateNodeProgress(v) {
progress = v;
graph.setDirtyCanvas(true, false);
}
function setRunningNode(id) {
progress = null;
runningNodeId = id;
graph.setDirtyCanvas(true, false);
}
(() => { (() => {
function updateStatus(data) { function updateStatus(data) {
document.getElementById("queuesize").innerHTML = "Queue size: " + (data ? data.exec_info.queue_remaining : "ERR"); document.getElementById("queuesize").innerHTML = "Queue size: " + (data ? data.exec_info.queue_remaining : "ERR");
} }
//fix for colab and other things that don't support websockets.
function manually_fetch_queue() {
fetch('/prompt')
.then(response => response.json())
.then(data => {
updateStatus(data);
}).catch((response) => {updateStatus(null)});
}
let ws;
function createSocket(isReconnect) { function createSocket(isReconnect) {
if(ws) return;
let opened = false; let opened = false;
ws = new WebSocket(`ws${window.location.protocol === "https:"? "s" : ""}://${location.host}/ws`); const ws = io();
ws.addEventListener("open", () => { ws.on("connect", () => {
opened = true; clientId = ws.id;
if(isReconnect) {
if(opened) {
closeModal(); closeModal();
} else {
opened = true;
} }
}); });
ws.addEventListener("error", () => { ws.on("disconnect", () => {
if(ws) ws.close();
manually_fetch_queue();
});
ws.addEventListener("close", () => {
setTimeout(() => {
ws = null;
createSocket(true);
}, 300);
if(opened) { if(opened) {
updateStatus(null); updateStatus(null);
showModal("Reconnecting..."); showModal("Reconnecting...");
} }
}); });
ws.addEventListener("message", (event) => { ws.on("status", (data) => {
try { updateStatus(data);
const msg = JSON.parse(event.data); });
switch(msg.type) {
case "status": ws.on("progress", (data) => {
updateStatus(msg.status) updateNodeProgress(data);
break; });
default:
throw new Error("Unknown message type") ws.on("execute", (data) => {
setRunningNode(data.node);
});
ws.on("image", (data) => {
images[data.id] = data.images;
const container = document.getElementById("images");
container.replaceChildren(...Object.values(images).map(src => {
const img = document.createElement("img");
img.src = src;
img.onclick = () => {
const modal = showModal();
const modalText = document.getElementById("modal-text");
modalText.innerHTML = `<img src="${img.src}"/>`
modal.setAttribute("style", modal.getAttribute("style") + "; background: #202020")
} }
} catch (error) { return img;
console.warn("Unhandled message:", event.data) }))
}
}); });
} }
createSocket(); createSocket();
...@@ -766,7 +863,7 @@ function clearQueue() { ...@@ -766,7 +863,7 @@ function clearQueue() {
</script> </script>
<span style="font-size: 15px;position: absolute; top: 50%; right: 0%; background-color: white; text-align: center; z-index: 100;width:170px;"> <span style="font-size: 15px;position: absolute; top: 50%; right: 0%; background-color: white; text-align: center; z-index: 100;width:170px; transform: translateY(-50%);">
<span id="queuesize">Queue size: X</span><br> <span id="queuesize">Queue size: X</span><br>
<button style="font-size: 20px;width: 100%;" id="queuebutton" onclick="postPrompt(0)">Queue Prompt</button><br> <button style="font-size: 20px;width: 100%;" id="queuebutton" onclick="postPrompt(0)">Queue Prompt</button><br>
<span style="left: 0%;"> <span style="left: 0%;">
...@@ -801,6 +898,7 @@ function clearQueue() { ...@@ -801,6 +898,7 @@ function clearQueue() {
<br> <br>
<button style="font-size: 20px;" onclick="clearGraph()">Clear</button><br> <button style="font-size: 20px;" onclick="clearGraph()">Clear</button><br>
<button style="font-size: 20px;" onclick="loadTxt2Img()">Load Default</button><br> <button style="font-size: 20px;" onclick="loadTxt2Img()">Load Default</button><br>
<div id="images"></div>
</span> </span>
</body> </body>
</html> </html>
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment