"comfy/ldm/modules/vscode:/vscode.git/clone" did not exist on "bb4940d837f0cfd338ff64776b084303be066c67"
server.py 27.1 KB
Newer Older
pythongosssss's avatar
pythongosssss committed
1
2
3
import os
import sys
import asyncio
4
5
import traceback

pythongosssss's avatar
pythongosssss committed
6
import nodes
7
import folder_paths
8
import execution
pythongosssss's avatar
pythongosssss committed
9
import uuid
10
import urllib
pythongosssss's avatar
pythongosssss committed
11
import json
12
import glob
space-nuko's avatar
space-nuko committed
13
import struct
Garrett Sutula's avatar
Garrett Sutula committed
14
import ssl
15
from PIL import Image, ImageOps
Chris's avatar
Chris committed
16
from PIL.PngImagePlugin import PngInfo
17
18
from io import BytesIO

comfyanonymous's avatar
comfyanonymous committed
19
20
import aiohttp
from aiohttp import web
comfyanonymous's avatar
comfyanonymous committed
21
import logging
pythongosssss's avatar
pythongosssss committed
22

comfyanonymous's avatar
Style.  
comfyanonymous committed
23
import mimetypes
EllangoK's avatar
EllangoK committed
24
from comfy.cli_args import args
25
import comfy.utils
space-nuko's avatar
space-nuko committed
26
import comfy.model_management
27

28
from app.user_manager import UserManager
space-nuko's avatar
space-nuko committed
29
30
31

class BinaryEventTypes:
    """Integer tags for binary websocket messages.

    ``encode_bytes`` packs the tag as a big-endian 32-bit header in front of
    the payload, so values must be plain ints.
    """
    # Already-encoded preview image bytes (header + encoded image).
    PREVIEW_IMAGE = 1
    # A (format, PIL image, max_size) tuple that ``send`` encodes before sending.
    UNENCODED_PREVIEW_IMAGE = 2


async def send_socket_catch_exception(function, message):
    """Await ``function(message)``, logging instead of raising on send failures.

    Args:
        function: async send callable, e.g. a websocket's ``send_json`` or
            ``send_bytes`` bound method.
        message: payload handed unchanged to ``function``.

    Only aiohttp client/payload errors and ``ConnectionResetError`` are
    swallowed (as a warning); anything else propagates to the caller.
    """
    try:
        await function(message)
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as send_failure:
        # Deliberately non-fatal: callers loop over many sockets and should
        # keep going when one peer has disconnected.
        logging.warning("send error: {}".format(send_failure))


@web.middleware
async def cache_control(request: web.Request, handler):
    """Middleware adding ``Cache-Control: no-cache`` to JavaScript and CSS responses.

    ``setdefault`` is used so a Cache-Control header already chosen by the
    handler is left untouched.
    """
    response: web.Response = await handler(request)
    serves_script_or_style = request.path.endswith(('.js', '.css'))
    if serves_script_or_style:
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response


def create_cors_middleware(allowed_origin: str):
    """Return an aiohttp middleware that attaches CORS headers for ``allowed_origin``.

    Pre-flight ``OPTIONS`` requests are answered directly with an empty
    ``web.Response()`` and never reach the route handler. Every other
    request is handled normally. The same fixed set of CORS headers is then
    written onto whichever response resulted, replacing any existing values.
    """
    cors_headers = (
        ('Access-Control-Allow-Origin', allowed_origin),
        ('Access-Control-Allow-Methods', 'POST, GET, DELETE, PUT, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization'),
        ('Access-Control-Allow-Credentials', 'true'),
    )

    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        if request.method != "OPTIONS":
            response = await handler(request)
        else:
            # Pre-flight request: reply successfully without invoking a handler.
            response = web.Response()

        for header_name, header_value in cors_headers:
            response.headers[header_name] = header_value
        return response

    return cors_middleware


class PromptServer():
    """aiohttp front end for the prompt queue: HTTP API, websockets and static files.

    ``__init__`` creates ``self.app`` and registers every API handler on
    ``self.routes`` (it does not mount them). ``add_routes`` mounts those
    routes, plus ``/api``-prefixed copies and the static web directories, on
    ``self.app``. ``start`` binds the listening socket. Messages produced
    outside the event loop are queued with ``send_sync`` and delivered to
    websocket clients by ``publish_loop``.
    """

    def __init__(self, loop):
        """Build the aiohttp application and register all route handlers.

        Args:
            loop: event loop onto which ``send_sync`` schedules queued messages.
        """
        # Class-level handle to the most recently constructed server.
        PromptServer.instance = self

        mimetypes.init()
        # Force a JavaScript MIME type for .js, overriding whatever mimetypes.init() loaded.
        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

        self.user_manager = UserManager()
        self.supports = ["custom_nodes_from_web"]
        # NOTE(review): left as None here; presumably assigned by the owner
        # before any request that touches the queue is served.
        self.prompt_queue = None
        self.loop = loop
        # (event, data, sid) tuples waiting to be delivered by publish_loop().
        self.messages = asyncio.Queue()
        # Queue "number" handed to prompts that do not supply one; bumped per such prompt.
        self.number = 0

        middlewares = [cache_control]
        if args.enable_cors_header:
            middlewares.append(create_cors_middleware(args.enable_cors_header))

        # args.max_upload_size is scaled by 1024 * 1024, i.e. expressed in MB.
        max_upload_size = round(args.max_upload_size * 1024 * 1024)
        self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
        # Connected websocket clients keyed by session id.
        self.sockets = dict()
        self.web_root = os.path.join(os.path.dirname(
            os.path.realpath(__file__)), "web")
        routes = web.RouteTableDef()
        self.routes = routes
        # Used to resend the current "executing" node to a reconnecting client.
        self.last_node_id = None
        self.client_id = None

        # Callables applied, in registration order, to every POST /prompt body.
        self.on_prompt_handlers = []

        @routes.get('/ws')
        async def websocket_handler(request):
            """Open a websocket; ``?clientId=`` lets a client reclaim a previous session id."""
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            sid = request.rel_url.query.get('clientId', '')
            if sid:
                # Reusing existing session, remove old
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client
                await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
                # On reconnect if we are the currently executing client send the current node
                if self.client_id == sid and self.last_node_id is not None:
                    await self.send("executing", { "node": self.last_node_id }, sid)

                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        logging.warning('ws connection closed with exception %s' % ws.exception())
            finally:
                # A reconnect with the same clientId replaces this entry. Only
                # remove it if it still belongs to this connection; otherwise
                # the newer socket would be silently dropped.
                if self.sockets.get(sid) is ws:
                    self.sockets.pop(sid, None)
            return ws

        @routes.get("/")
        async def get_root(request):
            return web.FileResponse(os.path.join(self.web_root, "index.html"))

        @routes.get("/embeddings")
        async def get_embeddings(request):
            """Return embedding names with their file extension stripped."""
            embeddings = folder_paths.get_filename_list("embeddings")
            return web.json_response([os.path.splitext(a)[0] for a in embeddings])

        @routes.get("/extensions")
        async def get_extensions(request):
            """Return URL paths of every frontend extension ``.js`` file.

            Includes files under ``<web_root>/extensions`` and under each
            directory registered in ``nodes.EXTENSION_WEB_DIRS`` (served at
            ``/extensions/<quoted name>/...`` by ``add_routes``).
            """
            files = glob.glob(os.path.join(
                glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)

            extensions = ["/" + os.path.relpath(f, self.web_root).replace("\\", "/") for f in files]

            for name, web_dir in nodes.EXTENSION_WEB_DIRS.items():
                files = glob.glob(os.path.join(glob.escape(web_dir), '**/*.js'), recursive=True)
                prefix = "/extensions/" + urllib.parse.quote(name) + "/"
                extensions.extend(prefix + os.path.relpath(f, web_dir).replace("\\", "/") for f in files)

            return web.json_response(extensions)

        def get_dir_by_type(dir_type):
            """Map an upload type name to its directory.

            Returns:
                ``(directory, normalized_type)``. ``None`` becomes ``"input"``.
                For an unrecognized type the directory is ``None``. Previously
                this raised UnboundLocalError, which surfaced as HTTP 500.
            """
            if dir_type is None:
                dir_type = "input"

            if dir_type == "input":
                type_dir = folder_paths.get_input_directory()
            elif dir_type == "temp":
                type_dir = folder_paths.get_temp_directory()
            elif dir_type == "output":
                type_dir = folder_paths.get_output_directory()
            else:
                type_dir = None

            return type_dir, dir_type

        def image_upload(post, image_save_function=None):
            """Store the uploaded ``image`` form field and describe where it went.

            Args:
                post: parsed multipart form (``image``, optional ``overwrite``,
                    ``type`` and ``subfolder``).
                image_save_function: optional ``(image, post, filepath)``
                    callable that writes the file instead of copying raw bytes.
                    Its return value is ignored; it must raise to abort.

            Returns:
                JSON ``{"name", "subfolder", "type"}`` on success, 400 otherwise.
                Unless ``overwrite`` is "true"/"1", an existing file causes the
                name to be suffixed with `` (1)``, `` (2)``, ... before saving.
            """
            image = post.get("image")
            overwrite = post.get("overwrite")

            image_upload_type = post.get("type")
            upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
            if upload_dir is None:
                return web.Response(status=400)

            if image and getattr(image, "file", None):
                filename = image.filename
                if not filename:
                    return web.Response(status=400)

                subfolder = post.get("subfolder", "")
                full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
                filepath = os.path.abspath(os.path.join(full_output_folder, filename))

                # Reject subfolder/filename combinations that escape upload_dir.
                if os.path.commonpath((upload_dir, filepath)) != upload_dir:
                    return web.Response(status=400)

                if not os.path.exists(full_output_folder):
                    os.makedirs(full_output_folder)

                split = os.path.splitext(filename)

                if overwrite not in ("true", "1"):
                    i = 1
                    while os.path.exists(filepath):
                        filename = f"{split[0]} ({i}){split[1]}"
                        filepath = os.path.join(full_output_folder, filename)
                        i += 1

                if image_save_function is not None:
                    image_save_function(image, post, filepath)
                else:
                    with open(filepath, "wb") as f:
                        f.write(image.file.read())

                return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
            else:
                return web.Response(status=400)

        @routes.post("/upload/image")
        async def upload_image(request):
            post = await request.post()
            return image_upload(post)

        @routes.post("/upload/mask")
        async def upload_mask(request):
            """Upload a mask; its alpha channel replaces that of the referenced original image."""
            post = await request.post()

            def image_save_function(image, post, filepath):
                # image_upload ignores this function's return value, so every
                # rejection below must raise an aiohttp HTTP error rather than
                # return a Response (which previously reported success).
                original_ref = json.loads(post.get("original_ref"))
                filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])

                # validation for security: prevent accessing arbitrary path
                # (also rejects an empty name, which used to raise IndexError)
                if not filename or filename[0] == '/' or '..' in filename:
                    raise web.HTTPBadRequest()

                if output_dir is None:
                    ref_type = original_ref.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(ref_type)

                if output_dir is None:
                    raise web.HTTPBadRequest()

                if original_ref.get("subfolder", "") != "":
                    full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        raise web.HTTPForbidden()
                    output_dir = full_output_dir

                file = os.path.join(output_dir, filename)

                # NOTE(review): if the original image does not exist nothing is
                # written, yet image_upload still reports success.
                if os.path.isfile(file):
                    with Image.open(file) as original_pil:
                        # Carry the original's PNG text chunks over to the result.
                        metadata = PngInfo()
                        if hasattr(original_pil,'text'):
                            for key in original_pil.text:
                                metadata.add_text(key, original_pil.text[key])
                        original_pil = original_pil.convert('RGBA')
                        mask_pil = Image.open(image.file).convert('RGBA')

                        # alpha copy
                        new_alpha = mask_pil.getchannel('A')
                        original_pil.putalpha(new_alpha)
                        original_pil.save(filepath, compress_level=4, pnginfo=metadata)

            return image_upload(post, image_save_function)

        @routes.get("/view")
        async def view_image(request):
            """Serve a stored image, optionally as a preview or a single channel set.

            Query: ``filename`` (required), ``type``, ``subfolder``,
            ``preview=<format>[;<quality>]``, and ``channel`` ("rgb", "a",
            or anything else / absent for the file as stored).
            Returns 400 for invalid names, 403 for escaping subfolders, and
            404 when the file is missing.
            """
            if "filename" in request.rel_url.query:
                filename = request.rel_url.query["filename"]
                filename,output_dir = folder_paths.annotated_filepath(filename)

                # validation for security: prevent accessing arbitrary path
                # (an empty name used to raise IndexError -> HTTP 500)
                if not filename or filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    dir_type = request.rel_url.query.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(dir_type)

                if output_dir is None:
                    return web.Response(status=400)

                if "subfolder" in request.rel_url.query:
                    full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        return web.Response(status=403)
                    output_dir = full_output_dir

                filename = os.path.basename(filename)
                file = os.path.join(output_dir, filename)

                if os.path.isfile(file):
                    if 'preview' in request.rel_url.query:
                        with Image.open(file) as img:
                            preview_info = request.rel_url.query['preview'].split(';')
                            image_format = preview_info[0]
                            # Only webp/jpeg previews; anything needing alpha goes out as webp.
                            if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
                                image_format = 'webp'

                            quality = 90
                            if preview_info[-1].isdigit():
                                quality = int(preview_info[-1])

                            buffer = BytesIO()
                            if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb':
                                img = img.convert("RGB")
                            img.save(buffer, format=image_format, quality=quality)
                            buffer.seek(0)

                            return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})

                    if 'channel' not in request.rel_url.query:
                        channel = 'rgba'
                    else:
                        channel = request.rel_url.query["channel"]

                    if channel == 'rgb':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                r, g, b, a = img.split()
                                new_img = Image.merge('RGB', (r, g, b))
                            else:
                                new_img = img.convert("RGB")

                            buffer = BytesIO()
                            new_img.save(buffer, format='PNG')
                            buffer.seek(0)

                            return web.Response(body=buffer.read(), content_type='image/png',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})

                    elif channel == 'a':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                _, _, _, a = img.split()
                            else:
                                # No alpha channel: treat the image as fully opaque.
                                a = Image.new('L', img.size, 255)

                            # alpha img
                            alpha_img = Image.new('RGBA', img.size)
                            alpha_img.putalpha(a)
                            alpha_buffer = BytesIO()
                            alpha_img.save(alpha_buffer, format='PNG')
                            alpha_buffer.seek(0)

                            return web.Response(body=alpha_buffer.read(), content_type='image/png',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})
                    else:
                        return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})

            return web.Response(status=404)

        @routes.get("/view_metadata/{folder_name}")
        async def view_metadata(request):
            """Return the ``__metadata__`` section of a ``.safetensors`` file header, or 404."""
            folder_name = request.match_info.get("folder_name", None)
            if folder_name is None:
                return web.Response(status=404)
            if "filename" not in request.rel_url.query:
                return web.Response(status=404)

            filename = request.rel_url.query["filename"]
            if not filename.endswith(".safetensors"):
                return web.Response(status=404)

            safetensors_path = folder_paths.get_full_path(folder_name, filename)
            if safetensors_path is None:
                return web.Response(status=404)
            out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
            if out is None:
                return web.Response(status=404)
            dt = json.loads(out)
            if "__metadata__" not in dt:
                return web.Response(status=404)
            return web.json_response(dt["__metadata__"])

        @routes.get("/system_stats")
        async def get_system_stats(request):
            """Report OS/Python info and memory figures for the active torch device."""
            device = comfy.model_management.get_torch_device()
            device_name = comfy.model_management.get_torch_device_name(device)
            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
            system_stats = {
                "system": {
                    "os": os.name,
                    "python_version": sys.version,
                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
                },
                "devices": [
                    {
                        "name": device_name,
                        "type": device.type,
                        "index": device.index,
                        "vram_total": vram_total,
                        "vram_free": vram_free,
                        "torch_vram_total": torch_vram_total,
                        "torch_vram_free": torch_vram_free,
                    }
                ]
            }
            return web.json_response(system_stats)

        @routes.get("/prompt")
        async def get_prompt(request):
            return web.json_response(self.get_queue_info())

        def node_info(node_class):
            """Describe a registered node class (inputs, outputs, names, category) for the frontend."""
            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
            info = {}
            info['input'] = obj_class.INPUT_TYPES()
            info['output'] = obj_class.RETURN_TYPES
            info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
            info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
            info['name'] = node_class
            info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
            info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
            info['category'] = 'sd'
            if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
                info['output_node'] = True
            else:
                info['output_node'] = False

            if hasattr(obj_class, 'CATEGORY'):
                info['category'] = obj_class.CATEGORY
            return info

        @routes.get("/object_info")
        async def get_object_info(request):
            """Describe every node class; classes that fail to describe are logged and omitted."""
            out = {}
            for x in nodes.NODE_CLASS_MAPPINGS:
                try:
                    out[x] = node_info(x)
                except Exception:
                    logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
                    logging.error(traceback.format_exc())
            return web.json_response(out)

        @routes.get("/object_info/{node_class}")
        async def get_object_info_node(request):
            node_class = request.match_info.get("node_class", None)
            out = {}
            if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
                out[node_class] = node_info(node_class)
            return web.json_response(out)

        @routes.get("/history")
        async def get_history(request):
            """Return queue history, optionally limited by ``?max_items=<int>`` (400 if not an int)."""
            max_items = request.rel_url.query.get("max_items", None)
            if max_items is not None:
                try:
                    max_items = int(max_items)
                except ValueError:
                    return web.Response(status=400)
            return web.json_response(self.prompt_queue.get_history(max_items=max_items))

        @routes.get("/history/{prompt_id}")
        async def get_history_prompt_id(request):
            prompt_id = request.match_info.get("prompt_id", None)
            return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))

        @routes.get("/queue")
        async def get_queue(request):
            queue_info = {}
            current_queue = self.prompt_queue.get_current_queue()
            queue_info['queue_running'] = current_queue[0]
            queue_info['queue_pending'] = current_queue[1]
            return web.json_response(queue_info)

        @routes.post("/prompt")
        async def post_prompt(request):
            """Validate and enqueue a prompt.

            The body passes through ``trigger_on_prompt`` first. Without an
            explicit ``number`` the server counter is used (negated when
            ``front`` is truthy). Returns ``{"prompt_id", "number",
            "node_errors"}``, or 400 with the validation error.
            """
            logging.info("got prompt")
            json_data =  await request.json()
            json_data = self.trigger_on_prompt(json_data)

            if "number" in json_data:
                number = float(json_data['number'])
            else:
                number = self.number
                if "front" in json_data:
                    if json_data['front']:
                        number = -number

                self.number += 1

            if "prompt" in json_data:
                prompt = json_data["prompt"]
                valid = execution.validate_prompt(prompt)
                extra_data = {}
                if "extra_data" in json_data:
                    extra_data = json_data["extra_data"]

                if "client_id" in json_data:
                    extra_data["client_id"] = json_data["client_id"]
                if valid[0]:
                    prompt_id = str(uuid.uuid4())
                    outputs_to_execute = valid[2]
                    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
                    response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
                    return web.json_response(response)
                else:
                    logging.warning("invalid prompt: {}".format(valid[1]))
                    return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
            else:
                return web.json_response({"error": "no prompt", "node_errors": []}, status=400)

        @routes.post("/queue")
        async def post_queue(request):
            """Clear the pending queue (``clear``) and/or delete items by prompt id (``delete``)."""
            json_data =  await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_queue()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    # Queue items are tuples whose index 1 is the prompt id (see post_prompt).
                    delete_func = lambda a: a[1] == id_to_delete
                    self.prompt_queue.delete_queue_item(delete_func)

            return web.Response(status=200)

        @routes.post("/interrupt")
        async def post_interrupt(request):
            nodes.interrupt_processing()
            return web.Response(status=200)

        @routes.post("/free")
        async def post_free(request):
            """Set the queue's ``unload_models`` / ``free_memory`` flags when requested truthy."""
            json_data = await request.json()
            unload_models = json_data.get("unload_models", False)
            free_memory = json_data.get("free_memory", False)
            if unload_models:
                self.prompt_queue.set_flag("unload_models", unload_models)
            if free_memory:
                self.prompt_queue.set_flag("free_memory", free_memory)
            return web.Response(status=200)

        @routes.post("/history")
        async def post_history(request):
            """Clear the history (``clear``) and/or delete individual entries (``delete``)."""
            json_data =  await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_history()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

    def add_routes(self):
        """Mount the registered routes, their ``/api`` aliases and the static directories on ``self.app``."""
        self.user_manager.add_routes(self.routes)

        # Prefix every route with /api for easier matching for delegation.
        # This is very useful for frontend dev server, which need to forward
        # everything except serving of static files.
        # Currently both the old endpoints without prefix and new endpoints with
        # prefix are supported.
        api_routes = web.RouteTableDef()
        for route in self.routes:
            # Custom nodes might add extra static routes. Only process non-static
            # routes to add /api prefix.
            if isinstance(route, web.RouteDef):
                api_routes.route(route.method, "/api" + route.path)(route.handler, **route.kwargs)
        self.app.add_routes(api_routes)
        self.app.add_routes(self.routes)

        for name, web_dir in nodes.EXTENSION_WEB_DIRS.items():
            self.app.add_routes([
                web.static('/extensions/' + urllib.parse.quote(name), web_dir),
            ])

        # Registered last so the catch-all "/" static route does not shadow the API.
        self.app.add_routes([
            web.static('/', self.web_root),
        ])

    def get_queue_info(self):
        """Return ``{"exec_info": {"queue_remaining": <count from prompt_queue>}}``."""
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    async def send(self, event, data, sid=None):
        """Dispatch a message: unencoded previews are encoded, bytes go binary, everything else JSON.

        ``sid=None`` broadcasts to every connected socket.
        """
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            await self.send_image(data, sid=sid)
        elif isinstance(data, (bytes, bytearray)):
            await self.send_bytes(event, data, sid)
        else:
            await self.send_json(event, data, sid)

    def encode_bytes(self, event, data):
        """Prefix ``data`` with ``event`` packed as a big-endian unsigned 32-bit int.

        Raises:
            RuntimeError: if ``event`` is not an int.
        """
        if not isinstance(event, int):
            raise RuntimeError(f"Binary event types must be integers, got {event}")

        packed = struct.pack(">I", event)
        message = bytearray(packed)
        message.extend(data)
        return message

    async def send_image(self, image_data, sid=None):
        """Encode a preview image and send it as a PREVIEW_IMAGE binary event.

        Args:
            image_data: ``(image_type, image, max_size)`` where ``image_type``
                is a PIL format name ("JPEG" or "PNG" are tagged 1 / 2), and
                ``image`` is a PIL image. If ``max_size`` is not None, the
                image is shrunk to fit a ``max_size`` x ``max_size`` box.
            sid: target session id, or None to broadcast.

        The payload is a 4-byte format tag followed by the encoded image.
        """
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            # Pillow >= 9.1 exposes Image.Resampling; older versions only have the constant.
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.ANTIALIAS

            image = ImageOps.contain(image, (max_size, max_size), resampling)
        type_num = 1
        if image_type == "JPEG":
            type_num = 1
        elif image_type == "PNG":
            type_num = 2

        bytesIO = BytesIO()
        header = struct.pack(">I", type_num)
        bytesIO.write(header)
        image.save(bytesIO, format=image_type, quality=95, compress_level=1)
        preview_bytes = bytesIO.getvalue()
        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

    async def send_bytes(self, event, data, sid=None):
        """Send an ``encode_bytes`` frame to one session, or to all when ``sid`` is None."""
        message = self.encode_bytes(event, data)

        if sid is None:
            # Snapshot: sockets may connect/disconnect while we await each send.
            sockets = list(self.sockets.values())
            for ws in sockets:
                await send_socket_catch_exception(ws.send_bytes, message)
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_bytes, message)

    async def send_json(self, event, data, sid=None):
        """Send ``{"type": event, "data": data}`` to one session, or to all when ``sid`` is None."""
        message = {"type": event, "data": data}

        if sid is None:
            # Snapshot: sockets may connect/disconnect while we await each send.
            sockets = list(self.sockets.values())
            for ws in sockets:
                await send_socket_catch_exception(ws.send_json, message)
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_json, message)

    def send_sync(self, event, data, sid=None):
        """Queue a message for ``publish_loop`` via ``loop.call_soon_threadsafe`` (usable off-loop)."""
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    def queue_updated(self):
        """Broadcast the current queue status to all clients."""
        self.send_sync("status", { "status": self.get_queue_info() })

    async def publish_loop(self):
        """Forever deliver messages queued by ``send_sync``."""
        while True:
            msg = await self.messages.get()
            await self.send(*msg)

    async def start(self, address, port, verbose=True, call_on_start=None):
        """Start serving on ``address:port``, over TLS when both TLS files are configured.

        Args:
            verbose: log the startup banner and URL.
            call_on_start: optional ``(scheme, address, port)`` callback run
                once the site is listening.
        """
        runner = web.AppRunner(self.app, access_log=None)
        await runner.setup()
        ssl_ctx = None
        scheme = "http"
        if args.tls_keyfile and args.tls_certfile:
            ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER)
            # verify_mode is not a documented SSLContext() argument, so set it
            # explicitly: client certificates are not requested.
            ssl_ctx.verify_mode = ssl.CERT_NONE
            ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
                                    keyfile=args.tls_keyfile)
            scheme = "https"

        site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
        await site.start()

        if verbose:
            logging.info("Starting server\n")
            logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
        if call_on_start is not None:
            call_on_start(scheme, address, port)

    def add_on_prompt_handler(self, handler):
        """Register ``handler(json_data) -> json_data`` to rewrite POST /prompt bodies."""
        self.on_prompt_handlers.append(handler)

    def trigger_on_prompt(self, json_data):
        """Run every on-prompt handler in order and return the final body.

        A handler that raises is logged and skipped; the body produced by the
        preceding handlers is kept.
        """
        for handler in self.on_prompt_handlers:
            try:
                json_data = handler(json_data)
            except Exception:
                logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
                logging.warning(traceback.format_exc())

        return json_data