server.py 28.2 KB
Newer Older
pythongosssss's avatar
pythongosssss committed
1
2
3
import os
import sys
import asyncio
4
5
import traceback

pythongosssss's avatar
pythongosssss committed
6
import nodes
7
import folder_paths
8
import execution
pythongosssss's avatar
pythongosssss committed
9
import uuid
10
import urllib
pythongosssss's avatar
pythongosssss committed
11
import json
12
import glob
space-nuko's avatar
space-nuko committed
13
import struct
Garrett Sutula's avatar
Garrett Sutula committed
14
import ssl
15
import hashlib
16
from PIL import Image, ImageOps
Chris's avatar
Chris committed
17
from PIL.PngImagePlugin import PngInfo
18
19
from io import BytesIO

comfyanonymous's avatar
comfyanonymous committed
20
21
import aiohttp
from aiohttp import web
comfyanonymous's avatar
comfyanonymous committed
22
import logging
pythongosssss's avatar
pythongosssss committed
23

comfyanonymous's avatar
Style.  
comfyanonymous committed
24
import mimetypes
EllangoK's avatar
EllangoK committed
25
from comfy.cli_args import args
26
import comfy.utils
space-nuko's avatar
space-nuko committed
27
import comfy.model_management
28
from app.frontend_management import FrontendManager
29
from app.user_manager import UserManager
space-nuko's avatar
space-nuko committed
30

31

space-nuko's avatar
space-nuko committed
32
33
class BinaryEventTypes:
    """Integer event tags used for binary websocket messages.

    A tag is packed as a big-endian unsigned 32-bit integer in front of the
    payload (see ``PromptServer.encode_bytes``).
    """

    # Encoded preview image bytes sent to clients.
    PREVIEW_IMAGE = 1
    # Event given to ``PromptServer.send`` with an un-encoded
    # ``(image_type, image, max_size)`` tuple; it is encoded by ``send_image``
    # and transmitted as PREVIEW_IMAGE.
    UNENCODED_PREVIEW_IMAGE = 2
space-nuko's avatar
space-nuko committed
35

36
37
38
39
async def send_socket_catch_exception(function, message):
    """Await ``function(message)``, logging instead of raising connection errors.

    Used when pushing data to websockets: a client that has gone away should
    produce a warning, not abort the sender.
    """
    try:
        await function(message)
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as error:
        logging.warning("send error: {}".format(error))
space-nuko's avatar
space-nuko committed
41

42
43
44
45
46
47
48
@web.middleware
async def cache_control(request: web.Request, handler):
    """aiohttp middleware giving ``.js``/``.css`` responses ``Cache-Control: no-cache``.

    ``setdefault`` keeps any Cache-Control header the handler already set.
    """
    response: web.Response = await handler(request)
    is_script_or_style = request.path.endswith(('.js', '.css'))
    if is_script_or_style:
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response

EllangoK's avatar
EllangoK committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def create_cors_middleware(allowed_origin: str):
    """Build an aiohttp middleware adding CORS headers for *allowed_origin*.

    ``OPTIONS`` (pre-flight) requests are answered immediately with an empty
    ``web.Response()``; every other request goes through the handler. The same
    CORS headers are then set on whichever response is returned.
    """
    cors_headers = {
        'Access-Control-Allow-Origin': allowed_origin,
        'Access-Control-Allow-Methods': 'POST, GET, DELETE, PUT, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
        'Access-Control-Allow-Credentials': 'true',
    }

    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        # Pre-flight requests never reach the route handler.
        if request.method != "OPTIONS":
            response = await handler(request)
        else:
            response = web.Response()

        for header_name, header_value in cors_headers.items():
            response.headers[header_name] = header_value
        return response

    return cors_middleware
EllangoK's avatar
EllangoK committed
65

pythongosssss's avatar
pythongosssss committed
66
67
class PromptServer():
    def __init__(self, loop):
68
        # Make this object reachable as the class-wide instance.
        PromptServer.instance = self

        mimetypes.init()
        # Serve JavaScript with an explicit UTF-8 charset.
        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

        self.user_manager = UserManager()
        # Feature flags advertised by this server.
        self.supports = ["custom_nodes_from_web"]
        # Not assigned in this constructor; the queue/history routes below call
        # methods on it, so it must be set before requests are served.
        self.prompt_queue = None
        self.loop = loop
        self.messages = asyncio.Queue()
        # Counter used to number prompts that arrive without an explicit number.
        self.number = 0

        middleware_chain = [cache_control]
        if args.enable_cors_header:
            middleware_chain.append(create_cors_middleware(args.enable_cors_header))

        # args.max_upload_size is converted from megabytes to bytes here.
        upload_limit_bytes = round(args.max_upload_size * 1024 * 1024)
        self.app = web.Application(client_max_size=upload_limit_bytes, middlewares=middleware_chain)
        # Session id -> open websocket.
        self.sockets = {}

        if args.front_end_root is None:
            self.web_root = FrontendManager.init_frontend(args.front_end_version)
        else:
            self.web_root = args.front_end_root
        logging.info(f"[Prompt Server] web root: {self.web_root}")

        # `routes` is used as a decorator by the nested handlers defined below.
        routes = web.RouteTableDef()
        self.routes = routes

        self.last_node_id = None
        self.client_id = None

        # Presumably consulted by trigger_on_prompt (defined outside this view).
        self.on_prompt_handlers = []

pythongosssss's avatar
pythongosssss committed
100
101
102
103
        @routes.get('/ws')
        async def websocket_handler(request):
            """Serve one client websocket for the lifetime of the connection.

            An optional ``clientId`` query parameter lets a reconnecting client
            keep its session id, replacing any socket registered under it.
            Otherwise a fresh random id is assigned. On connect the client gets
            a ``status`` event. If it is the client whose prompt is executing,
            it also gets the current ``executing`` node. Incoming frames are
            only inspected to log errors.
            """
            ws = web.WebSocketResponse()
            await ws.prepare(request)

            sid = request.rel_url.query.get('clientId', '')
            if sid:
                # Reusing existing session, remove old
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client
                await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
                # On reconnect if we are the currently executing client send the current node
                if self.client_id == sid and self.last_node_id is not None:
                    await self.send("executing", { "node": self.last_node_id }, sid)

                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        logging.warning('ws connection closed with exception %s' % ws.exception())
            finally:
                # Only unregister our own socket. If the client reconnected with
                # the same clientId meanwhile, the entry now belongs to the newer
                # connection and must be left alone.
                if self.sockets.get(sid) is ws:
                    self.sockets.pop(sid, None)
            return ws

        @routes.get("/")
        async def get_root(request):
            """Serve ``index.html`` from the configured web root."""
            index_path = os.path.join(self.web_root, "index.html")
            return web.FileResponse(index_path)
130

131
132
        @routes.get("/embeddings")
        async def get_embeddings(request):
            """Return the names of available embeddings without file extensions.

            Declared ``async`` because aiohttp deprecates plain-function request
            handlers. The argument is the aiohttp request (not the server
            instance), so it must not be named ``self`` here.
            """
            embeddings = folder_paths.get_filename_list("embeddings")
            return web.json_response([os.path.splitext(name)[0] for name in embeddings])
135

136
137
        @routes.get("/extensions")
        async def get_extensions(request):
            """List URL paths of every frontend extension script (``*.js``).

            Covers scripts under ``<web_root>/extensions`` and under each
            registered custom-node web directory. The latter are exposed as
            ``/extensions/<quoted name>/...``. Path separators are normalized to
            ``/``.
            """
            def to_url(path, base, prefix):
                return prefix + os.path.relpath(path, base).replace("\\", "/")

            core_pattern = os.path.join(glob.escape(self.web_root), 'extensions/**/*.js')
            extensions = [to_url(path, self.web_root, "/")
                          for path in glob.glob(core_pattern, recursive=True)]

            for ext_name, ext_dir in nodes.EXTENSION_WEB_DIRS.items():
                pattern = os.path.join(glob.escape(ext_dir), '**/*.js')
                url_prefix = "/extensions/" + urllib.parse.quote(ext_name) + "/"
                extensions.extend(to_url(path, ext_dir, url_prefix)
                                  for path in glob.glob(pattern, recursive=True))

            return web.json_response(extensions)
149

150
151
        def get_dir_by_type(dir_type):
            """Resolve an upload ``type`` to ``(directory, normalized_type)``.

            ``None`` is treated as ``"input"``. Recognized types are ``"input"``,
            ``"temp"`` and ``"output"``. For any other value the returned
            directory is ``None`` instead of raising ``UnboundLocalError``.
            """
            if dir_type is None:
                dir_type = "input"

            if dir_type == "input":
                type_dir = folder_paths.get_input_directory()
            elif dir_type == "temp":
                type_dir = folder_paths.get_temp_directory()
            elif dir_type == "output":
                type_dir = folder_paths.get_output_directory()
            else:
                # Unknown type: report "no directory" so the caller can reject it.
                type_dir = None

            return type_dir, dir_type
comfyanonymous's avatar
comfyanonymous committed
162

163
164
165
166
167
168
169
170
171
172
173
174
        def compare_image_hash(filepath, image):
            """Return True if the file at *filepath* has the same bytes as the upload.

            Used to avoid storing a duplicate of an uploaded image under a new
            name (fix for #3465). Both sides are SHA-256 hashed in chunks, so
            large files are not loaded into memory at once. ``image.file`` is
            rewound before and after hashing, so the caller can still read it.

            Returns False when *filepath* is not an existing regular file.
            """
            if not os.path.isfile(filepath):
                return False

            def _digest(stream):
                # Incrementally hash a binary stream until EOF.
                hasher = hashlib.sha256()
                for chunk in iter(lambda: stream.read(65536), b""):
                    hasher.update(chunk)
                return hasher.digest()

            with open(filepath, "rb") as existing:
                existing_digest = _digest(existing)

            image.file.seek(0)
            upload_digest = _digest(image.file)
            image.file.seek(0)
            return existing_digest == upload_digest
comfyanonymous's avatar
comfyanonymous committed
175

comfyanonymous's avatar
comfyanonymous committed
176
        def image_upload(post, image_save_function=None):
            """Save the image from an upload form and describe where it was stored.

            Form fields read from *post*:
              ``image``     uploaded file (required, must have a filename).
              ``overwrite`` ``"true"``/``"1"`` replaces an existing file.
              ``type``      target directory type, resolved by ``get_dir_by_type``.
              ``subfolder`` optional sub-directory inside that directory.

            Without overwrite, a name clash is resolved by appending ``" (N)"``
            to the stem. An existing file with identical content is reused
            instead of being written again.

            ``image_save_function(image, post, filepath)``, if given, replaces
            the default raw write. If it returns an aiohttp response (e.g. a
            validation error), that response is sent to the client instead of
            the success payload.

            Returns JSON ``{"name", "subfolder", "type"}`` on success. Returns
            400 when the file/filename is missing, the type resolves to no
            directory, or the target path escapes the directory.
            """
            image = post.get("image")
            overwrite = post.get("overwrite")
            image_is_duplicate = False

            upload_dir, image_upload_type = get_dir_by_type(post.get("type"))
            if upload_dir is None:
                return web.Response(status=400)

            if not (image and image.file):
                return web.Response(status=400)

            filename = image.filename
            if not filename:
                return web.Response(status=400)

            subfolder = post.get("subfolder", "")
            full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
            filepath = os.path.abspath(os.path.join(full_output_folder, filename))

            # Reject names/subfolders that escape the upload directory. The
            # directory is normalized too, so a trailing separator or relative
            # path cannot make the comparison fail spuriously.
            upload_root = os.path.abspath(upload_dir)
            if os.path.commonpath((upload_root, filepath)) != upload_root:
                return web.Response(status=400)

            os.makedirs(full_output_folder, exist_ok=True)

            if overwrite not in ("true", "1"):
                stem, ext = os.path.splitext(filename)
                counter = 1
                while os.path.exists(filepath):
                    # compare hash to prevent saving of duplicates with same name, fix for #3465
                    if compare_image_hash(filepath, image):
                        image_is_duplicate = True
                        break
                    filename = f"{stem} ({counter}){ext}"
                    filepath = os.path.join(full_output_folder, filename)
                    counter += 1

            if not image_is_duplicate:
                if image_save_function is not None:
                    result = image_save_function(image, post, filepath)
                    # A custom saver may reject the request (e.g. bad reference).
                    if isinstance(result, web.StreamResponse):
                        return result
                else:
                    with open(filepath, "wb") as f:
                        f.write(image.file.read())

            return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})

comfyanonymous's avatar
comfyanonymous committed
224
225
226
227
228
        @routes.post("/upload/image")
        async def upload_image(request):
            """Store an uploaded image using the default raw-byte writer."""
            return image_upload(await request.post())

229

230
231
232
233
        @routes.post("/upload/mask")
        async def upload_mask(request):
            """Upload a mask and apply its alpha channel to an existing image.

            Besides the fields handled by ``image_upload``, the form must contain
            ``original_ref``. This is a JSON object with ``filename`` (which may
            carry an annotation resolved by ``folder_paths.annotated_filepath``)
            and optional ``type`` (default ``"output"``) and ``subfolder``.

            The referenced image is converted to RGBA and its alpha replaced by
            the mask's alpha. The result is saved as PNG at the upload's
            destination, keeping the original's text chunks.

            The reference is validated before ``image_upload`` runs. Bad or
            unsafe references therefore produce 400/403 responses instead of
            being ignored inside the save callback. If the referenced file does
            not exist, nothing is written.
            """
            post = await request.post()

            try:
                original_ref = json.loads(post.get("original_ref"))
            except (TypeError, ValueError):
                # Missing field (json.loads(None) -> TypeError) or malformed JSON.
                return web.Response(status=400)
            if not isinstance(original_ref, dict) or not isinstance(original_ref.get("filename"), str):
                return web.Response(status=400)

            filename, output_dir = folder_paths.annotated_filepath(original_ref["filename"])

            # validation for security: prevent accessing arbitrary path
            # (empty names, leading '/' or '\\', Windows drive prefixes, '..').
            if (not filename or filename[0] in "/\\"
                    or os.path.splitdrive(filename)[0] or '..' in filename):
                return web.Response(status=400)

            if output_dir is None:
                ref_type = original_ref.get("type", "output")
                output_dir = folder_paths.get_directory_by_type(ref_type)

            if output_dir is None:
                return web.Response(status=400)

            if original_ref.get("subfolder", "") != "":
                full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
                if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                    return web.Response(status=403)
                output_dir = full_output_dir

            source_file = os.path.join(output_dir, filename)

            def image_save_function(image, post, filepath):
                if not os.path.isfile(source_file):
                    return

                with Image.open(source_file) as original_pil:
                    # Copy the original's PNG text chunks into the new file.
                    metadata = PngInfo()
                    if hasattr(original_pil, 'text'):
                        for key in original_pil.text:
                            metadata.add_text(key, original_pil.text[key])

                    merged = original_pil.convert('RGBA')
                    mask_pil = Image.open(image.file).convert('RGBA')

                    # alpha copy
                    merged.putalpha(mask_pil.getchannel('A'))
                    merged.save(filepath, compress_level=4, pnginfo=metadata)

            return image_upload(post, image_save_function)
pythongosssss's avatar
pythongosssss committed
272

273
        @routes.get("/view")
        async def view_image(request):
            """Serve a file from one of the managed directories.

            Query parameters:
              filename   required. It may carry an annotation resolved by
                         ``folder_paths.annotated_filepath``.
              type       directory type used when the filename has no
                         annotation (default ``"output"``).
              subfolder  optional sub-directory. It must stay inside the
                         directory.
              preview    ``"<format>;<quality>"``: re-encode as ``webp`` or
                         ``jpeg``. Any other format, or a ``channel`` value
                         containing ``"a"``, falls back to ``webp``. Quality
                         defaults to 90.
              channel    ``"rgba"`` (default: raw file), ``"rgb"`` or ``"a"``
                         (PNG containing only the RGB / alpha channel).

            Returns 400 for an empty or unsafe filename or an unknown type, 403
            for a subfolder outside the directory, and 404 when no file matches.
            """
            query = request.rel_url.query
            if "filename" in query:
                filename, output_dir = folder_paths.annotated_filepath(query["filename"])

                # validation for security: prevent accessing arbitrary path.
                # Testing for emptiness first avoids an IndexError (HTTP 500).
                if not filename or filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    output_dir = folder_paths.get_directory_by_type(query.get("type", "output"))

                if output_dir is None:
                    return web.Response(status=400)

                if "subfolder" in query:
                    full_output_dir = os.path.join(output_dir, query["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        return web.Response(status=403)
                    output_dir = full_output_dir

                filename = os.path.basename(filename)
                file = os.path.join(output_dir, filename)

                if os.path.isfile(file):
                    disposition = {"Content-Disposition": f"filename=\"{filename}\""}
                    requested_channel = query.get('channel', '')

                    if 'preview' in query:
                        with Image.open(file) as img:
                            preview_info = query['preview'].split(';')
                            image_format = preview_info[0]
                            if image_format not in ['webp', 'jpeg'] or 'a' in requested_channel:
                                image_format = 'webp'

                            quality = 90
                            if preview_info[-1].isdigit():
                                quality = int(preview_info[-1])

                            buffer = BytesIO()
                            if image_format in ['jpeg'] or requested_channel == 'rgb':
                                img = img.convert("RGB")
                            img.save(buffer, format=image_format, quality=quality)

                            return web.Response(body=buffer.getvalue(), content_type=f'image/{image_format}',
                                                headers=disposition)

                    channel = query.get('channel', 'rgba')

                    if channel == 'rgb':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                r, g, b, _ = img.split()
                                new_img = Image.merge('RGB', (r, g, b))
                            else:
                                new_img = img.convert("RGB")

                            buffer = BytesIO()
                            new_img.save(buffer, format='PNG')
                            return web.Response(body=buffer.getvalue(), content_type='image/png',
                                                headers=disposition)

                    elif channel == 'a':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                _, _, _, a = img.split()
                            else:
                                # No alpha channel: treat the image as fully opaque.
                                a = Image.new('L', img.size, 255)

                            # alpha img
                            alpha_img = Image.new('RGBA', img.size)
                            alpha_img.putalpha(a)
                            alpha_buffer = BytesIO()
                            alpha_img.save(alpha_buffer, format='PNG')
                            return web.Response(body=alpha_buffer.getvalue(), content_type='image/png',
                                                headers=disposition)
                    else:
                        return web.FileResponse(file, headers=disposition)

            return web.Response(status=404)
360

361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
        @routes.get("/view_metadata/{folder_name}")
        async def view_metadata(request):
            """Return the ``__metadata__`` object from a ``.safetensors`` header.

            Responds 404 if the folder/filename is missing, the file is not a
            ``.safetensors`` file, it cannot be resolved, its header cannot be
            read, or the header has no ``__metadata__`` entry.
            """
            folder_name = request.match_info.get("folder_name", None)
            filename = request.rel_url.query.get("filename")
            if folder_name is None or filename is None or not filename.endswith(".safetensors"):
                return web.Response(status=404)

            safetensors_path = folder_paths.get_full_path(folder_name, filename)
            if safetensors_path is None:
                return web.Response(status=404)

            # max_size presumably caps how many header bytes are read (1 MiB here).
            header = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
            if header is None:
                return web.Response(status=404)

            parsed = json.loads(header)
            if "__metadata__" not in parsed:
                return web.Response(status=404)
            return web.json_response(parsed["__metadata__"])

space-nuko's avatar
space-nuko committed
384
385
        @routes.get("/system_stats")
        async def get_system_stats(request):
            """Report OS/Python details and memory figures for the torch device.

            Named ``get_system_stats`` so it no longer shares the local name
            ``get_queue`` with the ``/queue`` handler in this constructor.
            aiohttp dispatches on the registered route, not the Python name.
            """
            device = comfy.model_management.get_torch_device()
            device_name = comfy.model_management.get_torch_device_name(device)
            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)

            # True when the interpreter lives in a directory named "python_embeded".
            embedded_python = os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"

            system_stats = {
                "system": {
                    "os": os.name,
                    "python_version": sys.version,
                    "embedded_python": embedded_python,
                },
                "devices": [
                    {
                        "name": device_name,
                        "type": device.type,
                        "index": device.index,
                        "vram_total": vram_total,
                        "vram_free": vram_free,
                        "torch_vram_total": torch_vram_total,
                        "torch_vram_free": torch_vram_free,
                    }
                ]
            }
            return web.json_response(system_stats)

pythongosssss's avatar
pythongosssss committed
410
411
412
        @routes.get("/prompt")
        async def get_prompt(request):
            """Return the current queue status as produced by ``get_queue_info``."""
            queue_status = self.get_queue_info()
            return web.json_response(queue_status)
413

414
415
416
417
418
419
420
421
422
        def node_info(node_class):
            """Describe the registered node class *node_class* for the frontend.

            Reads attributes of ``nodes.NODE_CLASS_MAPPINGS[node_class]``.
            Optional attributes fall back to defaults. An unregistered name
            raises ``KeyError``.
            """
            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]

            info = {}
            info['input'] = obj_class.INPUT_TYPES()
            info['output'] = obj_class.RETURN_TYPES
            if hasattr(obj_class, 'OUTPUT_IS_LIST'):
                info['output_is_list'] = obj_class.OUTPUT_IS_LIST
            else:
                info['output_is_list'] = [False] * len(obj_class.RETURN_TYPES)
            info['output_name'] = getattr(obj_class, 'RETURN_NAMES', info['output'])
            info['name'] = node_class
            info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class)
            info['description'] = getattr(obj_class, 'DESCRIPTION', '')
            info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
            # 'category' is inserted here so its position in the JSON object is stable.
            info['category'] = getattr(obj_class, 'CATEGORY', 'sd')
            # Deliberately an equality test against True (not plain truthiness).
            info['output_node'] = bool(hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True)
            return info

pythongosssss's avatar
pythongosssss committed
435
436
437
438
        @routes.get("/object_info")
        async def get_object_info(request):
            """Describe every registered node class; failing classes are logged and skipped."""
            out = {}
            for node_class in nodes.NODE_CLASS_MAPPINGS:
                try:
                    out[node_class] = node_info(node_class)
                except Exception:
                    # One broken node definition must not hide all the others.
                    logging.error(f"[ERROR] An error occurred while retrieving information for the '{node_class}' node.")
                    logging.error(traceback.format_exc())
            return web.json_response(out)

        @routes.get("/object_info/{node_class}")
        async def get_object_info_node(request):
            """Describe one node class; an unknown class yields an empty object."""
            node_class = request.match_info.get("node_class", None)
            if node_class is None or node_class not in nodes.NODE_CLASS_MAPPINGS:
                return web.json_response({})
            return web.json_response({node_class: node_info(node_class)})
453

pythongosssss's avatar
pythongosssss committed
454
455
        @routes.get("/history")
        async def get_history(request):
            """Return the prompt history, optionally limited by ``max_items``.

            A ``max_items`` value that is not an integer is answered with 400
            instead of surfacing as an unhandled ``ValueError`` (HTTP 500).
            """
            max_items = request.rel_url.query.get("max_items", None)
            if max_items is not None:
                try:
                    max_items = int(max_items)
                except ValueError:
                    return web.Response(status=400)
            return web.json_response(self.prompt_queue.get_history(max_items=max_items))
460

461
462
463
464
465
        @routes.get("/history/{prompt_id}")
        async def get_history_prompt(request):
            """Return the history entry for a single prompt id.

            Named distinctly so it does not redefine the ``get_history`` handler
            of the ``/history`` route in this scope.
            """
            prompt_id = request.match_info.get("prompt_id", None)
            return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))

pythongosssss's avatar
pythongosssss committed
466
467
468
469
470
471
472
        @routes.get("/queue")
        async def get_queue(request):
            """Return the running and pending entries of the prompt queue."""
            current_queue = self.prompt_queue.get_current_queue()
            return web.json_response({
                'queue_running': current_queue[0],
                'queue_pending': current_queue[1],
            })
473

pythongosssss's avatar
pythongosssss committed
474
475
        @routes.post("/prompt")
        async def post_prompt(request):
            """Validate a submitted prompt and put it on the queue.

            JSON body fields:
              ``prompt``      the graph to execute (required).
              ``number``      explicit queue number (parsed as float).
              ``front``       without ``number``: a truthy value negates the
                              auto-assigned counter (presumably so it sorts
                              ahead; ordering is up to the queue).
              ``extra_data``  dict passed to the queue (``client_id`` is added).
              ``client_id``   copied into ``extra_data``.

            Returns ``{"prompt_id", "number", "node_errors"}`` on success.
            Returns 400 with ``{"error", "node_errors"}`` for invalid JSON, an
            unparseable ``number``, a missing prompt, or a prompt that fails
            validation.
            """
            logging.info("got prompt")
            try:
                json_data = await request.json()
            except ValueError:
                return web.json_response({"error": "invalid json", "node_errors": []}, status=400)
            json_data = self.trigger_on_prompt(json_data)

            if "number" in json_data:
                try:
                    number = float(json_data['number'])
                except (TypeError, ValueError):
                    return web.json_response({"error": "invalid number", "node_errors": []}, status=400)
            else:
                number = self.number
                if "front" in json_data and json_data['front']:
                    number = -number
                self.number += 1

            if "prompt" not in json_data:
                return web.json_response({"error": "no prompt", "node_errors": []}, status=400)

            prompt = json_data["prompt"]
            valid = execution.validate_prompt(prompt)
            extra_data = {}
            if "extra_data" in json_data:
                extra_data = json_data["extra_data"]

            if "client_id" in json_data:
                extra_data["client_id"] = json_data["client_id"]

            if not valid[0]:
                logging.warning("invalid prompt: {}".format(valid[1]))
                return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)

            prompt_id = str(uuid.uuid4())
            outputs_to_execute = valid[2]
            self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
            response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
            return web.json_response(response)
pythongosssss's avatar
pythongosssss committed
512
513
514
515
516
517
518
519
520
521

        @routes.post("/queue")
        async def post_queue(request):
            """Clear the queue and/or delete specific queued prompts.

            JSON body: ``clear`` (truthy wipes the queue) and ``delete`` (prompt
            ids whose queue items are removed).
            """
            json_data = await request.json()
            if "clear" in json_data and json_data["clear"]:
                self.prompt_queue.wipe_queue()

            if "delete" in json_data:
                for target_id in json_data['delete']:
                    # Queue items are tuples with the prompt id at index 1 (see
                    # post_prompt). The id is bound as a default so the predicate
                    # does not depend on the loop variable.
                    self.prompt_queue.delete_queue_item(lambda item, wanted=target_id: item[1] == wanted)

            return web.Response(status=200)
pythongosssss's avatar
pythongosssss committed
526
527
528
529
530
531

        @routes.post("/interrupt")
        async def post_interrupt(request):
            """Signal an interrupt through ``nodes.interrupt_processing()``."""
            nodes.interrupt_processing()
            ok = web.Response(status=200)
            return ok

532
        @routes.post("/free")
        async def post_free(request):
            """Forward truthy ``unload_models`` / ``free_memory`` values as queue flags.

            Values are passed to ``prompt_queue.set_flag``; acting on them
            happens outside this handler.
            """
            json_data = await request.json()
            for flag_name in ("unload_models", "free_memory"):
                flag_value = json_data.get(flag_name, False)
                if flag_value:
                    self.prompt_queue.set_flag(flag_name, flag_value)
            return web.Response(status=200)

pythongosssss's avatar
pythongosssss committed
543
544
545
546
547
        @routes.post("/history")
        async def post_history(request):
            """Clear the history and/or delete history entries by prompt id."""
            json_data = await request.json()
            if "clear" in json_data and json_data["clear"]:
                self.prompt_queue.wipe_history()

            if "delete" in json_data:
                for prompt_id in json_data['delete']:
                    self.prompt_queue.delete_history_item(prompt_id)

            return web.Response(status=200)
555

556
    def add_routes(self):
        """Register every HTTP route on the aiohttp application.

        Routes are registered in this order:
          1. ``/api``-prefixed copies of all non-static routes,
          2. the original (unprefixed) routes, including the user manager's,
          3. each custom node's static web dir at ``/extensions/<quoted name>``,
          4. the frontend's static files at ``/``.
        """
        self.user_manager.add_routes(self.routes)

        # Prefix every route with /api for easier matching for delegation.
        # This is very useful for frontend dev server, which need to forward
        # everything except serving of static files.
        # Currently both the old endpoints without prefix and new endpoints with
        # prefix are supported.
        api_routes = web.RouteTableDef()
        # Custom nodes might add extra static routes. Only process non-static
        # routes to add /api prefix.
        dynamic_routes = [route for route in self.routes if isinstance(route, web.RouteDef)]
        for route_def in dynamic_routes:
            register = api_routes.route(route_def.method, "/api" + route_def.path)
            register(route_def.handler, **route_def.kwargs)
        self.app.add_routes(api_routes)
        self.app.add_routes(self.routes)

        extension_statics = [
            web.static('/extensions/' + urllib.parse.quote(ext_name), ext_dir)
            for ext_name, ext_dir in nodes.EXTENSION_WEB_DIRS.items()
        ]
        self.app.add_routes(extension_statics)

        self.app.add_routes([web.static('/', self.web_root)])

    def get_queue_info(self):
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    async def send(self, event, data, sid=None):
        """Send *event* with *data* to client *sid*, or to all clients if None.

        - ``BinaryEventTypes.UNENCODED_PREVIEW_IMAGE``: *data* is handed to
          ``send_image`` to be encoded first.
        - ``bytes``/``bytearray`` payloads go out via ``send_bytes``.
        - anything else is sent via ``send_json``.
        """
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            await self.send_image(data, sid=sid)
            return

        if isinstance(data, (bytes, bytearray)):
            await self.send_bytes(event, data, sid)
            return

        await self.send_json(event, data, sid)

    def encode_bytes(self, event, data):
        """Frame ``data`` for a binary websocket message.

        The frame is a 4-byte big-endian unsigned event id followed by the
        contents of ``data``.

        Raises:
            RuntimeError: if ``event`` is not an ``int``.
        """
        if isinstance(event, int):
            framed = bytearray(struct.pack(">I", event))
            framed.extend(data)
            return framed
        raise RuntimeError(f"Binary event types must be integers, got {event}")

606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
    async def send_image(self, image_data, sid=None):
        """Encode a preview image and send it as a ``PREVIEW_IMAGE`` binary event.

        Args:
            image_data: indexable as ``(image_type, image, max_size)``.
                ``image_type`` is the format name passed to ``image.save``
                ("JPEG" or "PNG" get distinct tags). ``image`` is an image
                object accepted by ``ImageOps.contain`` and ``.save``
                (presumably a PIL ``Image``). ``max_size`` is ``None`` or the
                bound for both width and height.
            sid: target socket id, or ``None`` to send to every socket.

        The payload is a 4-byte big-endian type tag (1 = JPEG, 2 = PNG; any
        other format also gets 1) followed by the encoded image bytes.
        """
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]

        if max_size is not None:
            # Newer Pillow exposes filters on ``Image.Resampling``; older
            # releases only have the module-level constant. Use bilinear in
            # both cases. The previous fallback silently switched to
            # ``Image.ANTIALIAS``, which is a different filter.
            resampling = getattr(Image, 'Resampling', Image).BILINEAR
            image = ImageOps.contain(image, (max_size, max_size), resampling)

        # Unknown formats keep the JPEG tag, as before.
        type_num = {"JPEG": 1, "PNG": 2}.get(image_type, 1)

        buffer = BytesIO()
        buffer.write(struct.pack(">I", type_num))
        # ``quality`` is a JPEG save option and ``compress_level`` a PNG one;
        # both are passed so a single call covers either format.
        image.save(buffer, format=image_type, quality=95, compress_level=1)
        preview_bytes = buffer.getvalue()
        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

space-nuko's avatar
space-nuko committed
630
631
632
633
    async def send_bytes(self, event, data, sid=None):
        """Send ``data`` as a binary frame tagged with ``event``.

        The frame is built by ``encode_bytes``. With ``sid=None`` it goes to a
        snapshot of every connected socket. Otherwise it goes only to
        ``self.sockets[sid]``, and nowhere if that sid is not connected.
        Connection errors are logged and swallowed by
        ``send_socket_catch_exception``.
        """
        message = self.encode_bytes(event, data)

        if sid is not None:
            if sid in self.sockets:
                await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
            return

        # Snapshot the sockets so a (dis)connect while awaiting cannot
        # mutate the collection mid-iteration.
        for ws in list(self.sockets.values()):
            await send_socket_catch_exception(ws.send_bytes, message)
space-nuko's avatar
space-nuko committed
639
640

    async def send_json(self, event, data, sid=None):
        """Send ``{"type": event, "data": data}`` as a JSON websocket message.

        With ``sid=None`` the message goes to a snapshot of every connected
        socket. Otherwise it goes only to ``self.sockets[sid]``, and nowhere if
        that sid is not connected. Connection errors are logged and swallowed
        by ``send_socket_catch_exception``.
        """
        message = {"type": event, "data": data}

        if sid is not None:
            if sid in self.sockets:
                await send_socket_catch_exception(self.sockets[sid].send_json, message)
            return

        # Snapshot the sockets so a (dis)connect while awaiting cannot
        # mutate the collection mid-iteration.
        for ws in list(self.sockets.values()):
            await send_socket_catch_exception(ws.send_json, message)
pythongosssss's avatar
pythongosssss committed
649
650
651
652

    def send_sync(self, event, data, sid=None):
        """Queue ``(event, data, sid)`` for ``publish_loop`` from any thread.

        The enqueue is scheduled on ``self.loop`` via
        ``call_soon_threadsafe``, so callers outside the event loop thread can
        use it.
        """
        item = (event, data, sid)
        self.loop.call_soon_threadsafe(self.messages.put_nowait, item)
653

pythongosssss's avatar
pythongosssss committed
654
655
656
657
658
659
660
661
    def queue_updated(self):
        """Queue a "status" event carrying the current ``get_queue_info()``.

        No sid is given, so the event is addressed to every socket.
        """
        payload = {"status": self.get_queue_info()}
        self.send_sync("status", payload)

    async def publish_loop(self):
        """Drain ``self.messages`` forever, forwarding each item to ``send``.

        Each queued item is unpacked as positional arguments to ``send``.
        ``send_sync`` queues ``(event, data, sid)`` tuples. This coroutine
        never returns on its own.
        """
        while True:
            queued = await self.messages.get()
            await self.send(*queued)

662
    async def start(self, address, port, verbose=True, call_on_start=None):
        """Start serving ``self.app`` on ``address:port``.

        HTTPS is used when both ``args.tls_keyfile`` and ``args.tls_certfile``
        are set; otherwise plain HTTP.

        Args:
            address: host/interface passed to ``web.TCPSite``.
            port: TCP port passed to ``web.TCPSite``.
            verbose: when true, log a startup banner and the UI URL.
            call_on_start: optional callable, invoked as
                ``call_on_start(scheme, address, port)`` after the site has started.
        """
        runner = web.AppRunner(self.app, access_log=None)
        await runner.setup()

        ssl_ctx = None
        scheme = "http"
        if args.tls_keyfile and args.tls_certfile:
            # ``SSLContext()`` silently swallows unknown keyword arguments, so
            # passing ``verify_mode=`` to the constructor had no effect. Set it
            # as an attribute instead (client certificates are not requested).
            ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER)
            ssl_ctx.verify_mode = ssl.CERT_NONE
            ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
                                    keyfile=args.tls_keyfile)
            scheme = "https"

        site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
        await site.start()

        if verbose:
            # IPv6 literals must be bracketed to form a valid URL; this only
            # affects the logged URL, not what is bound or passed to callbacks.
            host = address
            if isinstance(address, str) and ":" in address and not address.startswith("["):
                host = "[{}]".format(address)
            logging.info("Starting server\n")
            logging.info("To see the GUI go to: {}://{}:{}".format(scheme, host, port))
        if call_on_start is not None:
            call_on_start(scheme, address, port)
681

682
683
684
685
686
687
688
689
    def add_on_prompt_handler(self, handler):
        """Register ``handler`` to be run by ``trigger_on_prompt``.

        Handlers run in registration order; each receives the prompt data and
        must return the (possibly modified) data.
        """
        handlers = self.on_prompt_handlers
        handlers.append(handler)

    def trigger_on_prompt(self, json_data):
        """Pass ``json_data`` through every registered on-prompt handler in order.

        Each handler receives the value returned by the previous one. A handler
        that raises is logged (message plus traceback at WARNING level) and
        skipped, so the data keeps the value it had before that handler ran.

        Args:
            json_data: the prompt payload to transform.

        Returns:
            The payload after all handlers have run.
        """
        for handler in self.on_prompt_handlers:
            try:
                json_data = handler(json_data)
            except Exception:
                # Best-effort: one failing handler must not stop the remaining
                # handlers or the prompt itself.
                logging.warning("[ERROR] An error occurred during the on_prompt_handler processing")
                logging.warning(traceback.format_exc())

        return json_data