Unverified Commit a52aa9f4 authored by pythongosssss's avatar pythongosssss Committed by GitHub
Browse files

Moved api out to server

Reworked sockets to use socketio
Added progress to nodes
Added highlight to active node
Added preview to saveimage node
parent 92808718
......@@ -11,15 +11,7 @@ if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
try:
import aiohttp
from aiohttp import web
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()
import server
if __name__ == "__main__":
if '--help' in sys.argv:
......@@ -36,14 +28,14 @@ if __name__ == "__main__":
print()
exit()
if '--dont-upcast-attention' in sys.argv:
print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16"
if '--dont-upcast-attention' in sys.argv:
print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16"
import torch
import nodes
def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}, server=None, unique_id=None):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
for x in inputs:
......@@ -65,9 +57,13 @@ def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo']
if h[x] == "SERVER":
input_data_all[x] = server
if h[x] == "UNIQUE_ID":
input_data_all[x] = unique_id
return input_data_all
def recursive_execute(prompt, outputs, current_item, extra_data={}):
def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
......@@ -84,9 +80,11 @@ def recursive_execute(prompt, outputs, current_item, extra_data={}):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
executed += recursive_execute(prompt, outputs, input_unique_id, extra_data)
executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)
input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data)
input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data, server, unique_id)
if server.client_id is not None:
server.send_sync("execute", { "node": unique_id }, server.client_id)
obj = class_def()
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
......@@ -157,11 +155,17 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
return to_delete
class PromptExecutor:
def __init__(self):
def __init__(self, server):
self.outputs = {}
self.old_prompt = {}
self.server = server
def execute(self, prompt, extra_data={}):
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
else:
self.server.client_id = None
with torch.no_grad():
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
......@@ -190,7 +194,7 @@ class PromptExecutor:
except:
valid = False
if valid:
executed += recursive_execute(prompt, self.outputs, x, extra_data)
executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
except Exception as e:
print(traceback.format_exc())
......@@ -208,6 +212,11 @@ class PromptExecutor:
executed = set(executed)
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
finally:
if self.server.client_id is not None:
self.server.send_sync("execute", { "node": None }, self.server.client_id)
torch.cuda.empty_cache()
def validate_inputs(prompt, item):
......@@ -293,27 +302,27 @@ def validate_prompt(prompt):
return (True, "")
def prompt_worker(q):
e = PromptExecutor()
def prompt_worker(q, server):
e = PromptExecutor(server)
while True:
item, item_id = q.get()
e.execute(item[-2], item[-1])
q.task_done(item_id)
class PromptQueue:
def __init__(self, socket_handler):
self.socket_handler = socket_handler
def __init__(self, server):
self.server = server
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0
self.queue = []
self.currently_running = {}
socket_handler.prompt_queue = self
server.prompt_queue = self
def put(self, item):
with self.mutex:
heapq.heappush(self.queue, item)
self.socket_handler.queue_updated(self)
self.server.queue_updated()
self.not_empty.notify()
def get(self):
......@@ -324,13 +333,13 @@ class PromptQueue:
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1
self.socket_handler.queue_updated(self)
self.server.queue_updated()
return (item, i)
def task_done(self, item_id):
with self.mutex:
self.currently_running.pop(item_id)
self.socket_handler.queue_updated(self)
self.server.queue_updated()
def get_current_queue(self):
with self.mutex:
......@@ -346,7 +355,7 @@ class PromptQueue:
def wipe_queue(self):
with self.mutex:
self.queue = []
self.socket_handler.queue_updated(self)
self.server.queue_updated()
def delete_queue_item(self, function):
with self.mutex:
......@@ -357,174 +366,32 @@ class PromptQueue:
else:
self.queue.pop(x)
heapq.heapify(self.queue)
self.socket_handler.queue_updated(self)
self.server.queue_updated()
return True
return False
def get_queue_info(prompt_queue):
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info
class SocketHandler():
def __init__(self, loop):
self.connected = set()
self.messages = asyncio.Queue()
self.loop = loop
async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(msg)
def queue_updated(self, queue):
# This is called by the queue processing thread so we need to make it thread safe
loop.call_soon_threadsafe(self.messages.put_nowait, { 'type': 'status', 'status': get_queue_info(queue) })
async def send(self, message, socket = None):
if isinstance(message, str) == False:
message = json.dumps(message)
if socket is None:
for ws in self.connected:
await ws.send_str(message)
else:
await socket.send_str(message)
async def process(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
self.connected.add(ws)
try:
# Send initial state to the new client
await self.send({ 'type': 'status', 'status': get_queue_info(self.prompt_queue) }, ws)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
finally:
self.connected.remove(ws)
return ws
class PromptServer():
def __init__(self, prompt_queue, socket_handler):
self.prompt_queue = prompt_queue
self.socket_handler = socket_handler
self.number = 0
self.app = web.Application()
self.web_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "webshit")
routes = web.RouteTableDef()
@routes.get('/ws')
async def websocket_handler(request):
return await self.socket_handler.process(request)
@routes.get("/")
async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html"))
@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(get_queue_info(self.prompt_queue))
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['name'] = x #TODO
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
return web.json_response(out)
@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@routes.post("/prompt")
async def post_prompt(request):
print("got prompt")
resp_code = 200
out_string = ""
json_data = await request.json()
if "number" in json_data:
number = float(json_data['number'])
else:
number = self.number
if "front" in json_data:
if json_data['front']:
number = -number
self.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = validate_prompt(prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
return web.Response(body=out_string, status=resp_code)
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.prompt_queue.delete_queue_item(delete_func)
return web.Response(status=200)
self.app.add_routes(routes)
self.app.add_routes([
web.static('/', self.web_root),
])
async def start_server(server, address, port):
runner = web.AppRunner(server.app)
await runner.setup()
site = web.TCPSite(runner, address, port)
await site.start()
if address == '':
address = '0.0.0.0'
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
async def run(server, address='', port=8188):
await asyncio.gather(server.start(address, port), server.publish_loop())
async def run(prompt_queue, socket_handler, address='', port=8188):
server = PromptServer(prompt_queue, socket_handler)
await asyncio.gather(start_server(server, address, port), socket_handler.publish_loop())
def hijack_progress(server):
from tqdm.auto import tqdm
orig_func = getattr(tqdm, "update")
def wrapped_func(*args, **kwargs):
pbar = args[0]
v = orig_func(*args, **kwargs)
server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
return v
setattr(tqdm, "update", wrapped_func)
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
q = PromptQueue(server)
socket_handler = SocketHandler(loop)
q = PromptQueue(socket_handler)
threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start()
hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
if '--listen' in sys.argv:
address = '0.0.0.0'
else:
......@@ -537,5 +404,11 @@ if __name__ == "__main__":
except:
pass
loop.run_until_complete(run(q, socket_handler, address=address, port=port))
if os.name == "nt":
try:
loop.run_until_complete(run(server, address=address, port=port))
except KeyboardInterrupt:
pass
else:
loop.run_until_complete(run(server, address=address, port=port))
......@@ -605,7 +605,7 @@ class SaveImage:
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "server": "SERVER", "unique_id": "UNIQUE_ID"},
}
RETURN_TYPES = ()
......@@ -615,7 +615,7 @@ class SaveImage:
CATEGORY = "image"
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, server=None, unique_id=None):
def map_filename(filename):
prefix_len = len(filename_prefix)
prefix = filename[:prefix_len + 1]
......@@ -631,6 +631,8 @@ class SaveImage:
except FileNotFoundError:
os.mkdir(self.output_dir)
counter = 1
paths = list()
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(i.astype(np.uint8))
......@@ -640,8 +642,12 @@ class SaveImage:
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
img.save(os.path.join(self.output_dir, f"{filename_prefix}_{counter:05}_.png"), pnginfo=metadata, optimize=True)
file = f"{filename_prefix}_{counter:05}_.png"
img.save(os.path.join(self.output_dir, file), pnginfo=metadata, optimize=True)
paths.append(f"/view/{file}")
counter += 1
if server is not None:
server.send_sync("image", {"images": paths, "id": unique_id})
class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
......
import os
import sys
import asyncio
import nodes
import main
try:
import aiohttp
from aiohttp import web
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()
try:
import socketio
except ImportError:
print("Module 'python-socketio' not installed. Please install it via:")
print("pip install python-socketio")
print("or")
print("pip install -r requirements.txt")
sys.exit()
class PromptServer():
def __init__(self, loop):
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
self.number = 0
self.app = web.Application()
self.sio = socketio.AsyncServer()
self.sio.attach(self.app)
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "webshit")
routes = web.RouteTableDef()
@self.sio.event
async def connect(sid, environ):
await self.sio.emit("status", self.get_queue_info(), sid)
@routes.get("/")
async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html"))
@routes.get("/view/{file}")
async def view_image(request):
if "file" in request.match_info:
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
file = request.match_info["file"]
file = os.path.splitext(os.path.basename(file))[0] + ".png"
file = os.path.join(output_dir, file)
if os.path.isfile(file):
return web.FileResponse(file)
return web.Response(status=404)
@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(self.get_queue_info())
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['name'] = x #TODO
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
return web.json_response(out)
@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@routes.post("/prompt")
async def post_prompt(request):
print("got prompt")
resp_code = 200
out_string = ""
json_data = await request.json()
if "number" in json_data:
number = float(json_data['number'])
else:
number = self.number
if "front" in json_data:
if json_data['front']:
number = -number
self.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = main.validate_prompt(prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
return web.Response(body=out_string, status=resp_code)
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.prompt_queue.delete_queue_item(delete_func)
return web.Response(status=200)
self.app.add_routes(routes)
self.app.add_routes([
web.static('/', self.web_root),
])
def get_queue_info(self):
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info
async def send(self, event, data, sid=None):
await self.sio.emit(event, data, to=sid)
def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))
def queue_updated(self):
self.send_sync("status", self.get_queue_info())
async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(*msg)
async def start(self, address, port):
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, address, port)
await site.start()
if address == '':
address = '0.0.0.0'
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
\ No newline at end of file
......@@ -2,6 +2,7 @@
<head>
<link rel="stylesheet" type="text/css" href="litegraph.css">
<script type="text/javascript" src="litegraph.core.js"></script>
<script type="text/javascript" src="socket.io.min.js"></script>
</head>
<style>
.customtext_input {
......@@ -26,7 +27,7 @@
left: 50%; /* Center the modal horizontally */
top: 50%; /* Center the modal vertically */
transform: translate(-50%, -50%); /* Use this to center the modal */
width: 50%; /* Set a width for the modal */
min-width: 50%; /* Set a width for the modal */
height: auto; /* Set a height for the modal */
padding: 30px;
background-color: #ff0000; /* Modal background */
......@@ -58,6 +59,18 @@
white-space: pre-line; /* This will respect line breaks */
margin-bottom: 20px; /* Add some margin between the text and the close button*/
}
#modal-text img {
max-width: calc(100vw - 96px - 36px);
max-height: calc(100vh - 96px - 36px);
}
#images img {
width: 100%;
max-height: 300px;
object-fit: contain;
cursor: pointer;
}
</style>
<div id="myErrorModal" class="modal">
<div class="modal-content">
......@@ -65,7 +78,6 @@
<span class="close">CLOSE</span>
</div>
</div>
<canvas id='mycanvas' width='1000' height='1000' style='width: 100%; height: 100%;'></canvas>
<script>
......@@ -75,6 +87,7 @@ var canvas = new LGraphCanvas("#mycanvas", graph);
const ccc = document.getElementById("mycanvas");
const ctx = ccc.getContext("2d");
let images = {}
// Resize the canvas to match the size of the canvas element
function resizeCanvas() {
......@@ -132,6 +145,7 @@ function onObjectInfo(json) {
this._widgets = []
min_height = 1;
min_width = 1;
for (let x in inp) {
let default_val = min_val = max_val = step_val = multiline = dynamic_prompt = undefined;
if (inp[x].length > 1) {
......@@ -261,6 +275,48 @@ function onObjectInfo(json) {
} else {
this.addInput(x, type);
}
if(key === "SaveImage") {
MyNode.prototype.onDrawBackground = function(ctx) {
if(this.id + "" in images) {
const src = images[this.id + ""][0];
if(this.src !== src) {
this.img = null;
this.src = src;
const img = new Image();
img.src = src;
img.onload = () => {
graph.setDirtyCanvas(true);
this.img = img;
if(this.size[1] < 100) {
this.size[1] = 250;
}
}
}
if(this.img) {
let w = this.img.naturalWidth;
let h = this.img.naturalHeight;
let dw = this.size[0];
let dh = this.size[1];
const scaleX = dw / w;
const scaleY = dh / h;
const scale = Math.min(scaleX, scaleY, 1);
w *= scale;
h *= scale;
let x = (dw - w) / 2;
let y = (dh - h) / 2;
ctx.drawImage(this.img, x, y, w, h);
}
} else {
this.size[1] = 58
}
};
}
}
out = j['output'];
......@@ -388,19 +444,19 @@ function graphToPrompt() {
function closeModal() {
var modal = document.getElementById("myErrorModal");
modal.style.display = "none";
modal.setAttribute("style", "");
}
function showModal(text) {
var modal = document.getElementById("myErrorModal");
var modalText = document.getElementById("modal-text");
modalText.innerHTML = text;
modal.style.display = "block";
modal.setAttribute("style", "display: block");
var closeBtn = modal.getElementsByClassName("close")[0];
closeBtn.onclick = function(event) {closeModal();}
return modal
}
function promptPosted(data)
{
if (data.status != 200) {
......@@ -431,7 +487,7 @@ function promptPosted(data)
function postPrompt(number) {
let prompt = graphToPrompt();
let full_data = {prompt: prompt, extra_data: {extra_pnginfo: {workflow: graph.serialize()}}};
let full_data = {client_id: clientId, prompt: prompt, extra_data: {extra_pnginfo: {workflow: graph.serialize()}}};
if (number == -1) {
full_data.front = true;
} else
......@@ -534,64 +590,105 @@ document.addEventListener('drop', (event) => {
prompt_file_load(file);
});
let runningNodeId = null;
let progress = null;
let clientId = null;
const orig = LGraphCanvas.prototype.drawNodeShape;
LGraphCanvas.prototype.drawNodeShape = function(node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
const res = orig.apply(this, arguments);
if(node.id + "" === runningNodeId) {
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
ctx.lineWidth = 1;
ctx.globalAlpha = 0.8;
ctx.beginPath();
if( shape == LiteGraph.BOX_SHAPE )
ctx.rect(-6,-6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0]+1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT );
else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed) )
ctx.roundRect(-6,-6 - LiteGraph.NODE_TITLE_HEIGHT, 12 +size[0]+1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT , this.round_radius * 2);
else if (shape == LiteGraph.CARD_SHAPE)
ctx.roundRect(-6,-6 + LiteGraph.NODE_TITLE_HEIGHT, 12 +size[0]+1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT , this.round_radius * 2, 2);
else if (shape == LiteGraph.CIRCLE_SHAPE)
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI*2);
ctx.strokeStyle = "#0f0"
ctx.stroke();
ctx.strokeStyle = fgcolor;
ctx.globalAlpha = 1;
if(progress) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (progress.value / progress.max), 6);
ctx.fillStyle = bgcolor;
}
}
return res;
}
function updateNodeProgress(v) {
progress = v;
graph.setDirtyCanvas(true, false);
}
function setRunningNode(id) {
progress = null;
runningNodeId = id;
graph.setDirtyCanvas(true, false);
}
(() => {
function updateStatus(data) {
document.getElementById("queuesize").innerHTML = "Queue size: " + (data ? data.exec_info.queue_remaining : "ERR");
}
//fix for colab and other things that don't support websockets.
function manually_fetch_queue() {
fetch('/prompt')
.then(response => response.json())
.then(data => {
updateStatus(data);
}).catch((response) => {updateStatus(null)});
}
let ws;
function createSocket(isReconnect) {
if(ws) return;
let opened = false;
ws = new WebSocket(`ws${window.location.protocol === "https:"? "s" : ""}://${location.host}/ws`);
const ws = io();
ws.addEventListener("open", () => {
opened = true;
if(isReconnect) {
ws.on("connect", () => {
clientId = ws.id;
if(opened) {
closeModal();
} else {
opened = true;
}
});
ws.addEventListener("error", () => {
if(ws) ws.close();
manually_fetch_queue();
});
ws.addEventListener("close", () => {
setTimeout(() => {
ws = null;
createSocket(true);
}, 300);
ws.on("disconnect", () => {
if(opened) {
updateStatus(null);
showModal("Reconnecting...");
}
});
ws.addEventListener("message", (event) => {
try {
const msg = JSON.parse(event.data);
switch(msg.type) {
case "status":
updateStatus(msg.status)
break;
default:
throw new Error("Unknown message type")
ws.on("status", (data) => {
updateStatus(data);
});
ws.on("progress", (data) => {
updateNodeProgress(data);
});
ws.on("execute", (data) => {
setRunningNode(data.node);
});
ws.on("image", (data) => {
images[data.id] = data.images;
const container = document.getElementById("images");
container.replaceChildren(...Object.values(images).map(src => {
const img = document.createElement("img");
img.src = src;
img.onclick = () => {
const modal = showModal();
const modalText = document.getElementById("modal-text");
modalText.innerHTML = `<img src="${img.src}"/>`
modal.setAttribute("style", modal.getAttribute("style") + "; background: #202020")
}
} catch (error) {
console.warn("Unhandled message:", event.data)
}
return img;
}))
});
}
createSocket();
......@@ -766,7 +863,7 @@ function clearQueue() {
</script>
<span style="font-size: 15px;position: absolute; top: 50%; right: 0%; background-color: white; text-align: center; z-index: 100;width:170px;">
<span style="font-size: 15px;position: absolute; top: 50%; right: 0%; background-color: white; text-align: center; z-index: 100;width:170px; transform: translateY(-50%);">
<span id="queuesize">Queue size: X</span><br>
<button style="font-size: 20px;width: 100%;" id="queuebutton" onclick="postPrompt(0)">Queue Prompt</button><br>
<span style="left: 0%;">
......@@ -801,6 +898,7 @@ function clearQueue() {
<br>
<button style="font-size: 20px;" onclick="clearGraph()">Clear</button><br>
<button style="font-size: 20px;" onclick="loadTxt2Img()">Load Default</button><br>
<div id="images"></div>
</span>
</body>
</html>
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment