import os
import sys
import asyncio
import traceback
import uuid
import urllib
import urllib.parse
import json
import glob
import struct
import logging
import mimetypes
from io import BytesIO

import aiohttp
from aiohttp import web
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo

import nodes
import folder_paths
import execution
from comfy.cli_args import args
import comfy.utils
import comfy.model_management

from app.user_manager import UserManager

class BinaryEventTypes:
    """Integer tags identifying binary WebSocket events.

    PromptServer.encode_bytes packs the tag as a big-endian unsigned 32-bit
    header in front of the payload and rejects non-int tags.
    """
    # Already-encoded preview image bytes.
    PREVIEW_IMAGE = 1
    # Not sent on the wire as-is: PromptServer.send routes this event to
    # send_image, which encodes the image and sends it as PREVIEW_IMAGE.
    UNENCODED_PREVIEW_IMAGE = 2


async def send_socket_catch_exception(function, message):
    """Await ``function(message)``, logging instead of raising send failures.

    aiohttp client/payload errors and ConnectionResetError are caught and
    reported with ``logging.warning``; any other exception propagates.
    """
    try:
        await function(message)
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
        logging.warning("send error: {}".format(err))


@web.middleware
async def cache_control(request: web.Request, handler):
    """aiohttp middleware: mark ``.js`` and ``.css`` responses ``no-cache``.

    ``setdefault`` keeps any Cache-Control header the handler already set.
    """
    response: web.Response = await handler(request)
    if request.path.endswith(('.js', '.css')):
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response


def create_cors_middleware(allowed_origin: str):
    """Build an aiohttp middleware that attaches CORS headers to every response.

    OPTIONS (pre-flight) requests are answered directly with a bare
    ``web.Response()`` without calling the route handler; every other request
    is passed to the handler. Both paths get the same fixed Access-Control-*
    headers.

    :param allowed_origin: written verbatim to Access-Control-Allow-Origin.
    :return: the middleware coroutine function.
    """
    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        if request.method == "OPTIONS":
            # Pre-flight request. Reply successfully:
            response = web.Response()
        else:
            response = await handler(request)

        response.headers['Access-Control-Allow-Origin'] = allowed_origin
        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        return response

    return cors_middleware


class PromptServer():
    """HTTP + WebSocket front end for the prompt queue, built on aiohttp.

    ``__init__`` creates the aiohttp application and declares every HTTP route
    on ``self.routes``; :meth:`add_routes` must be called afterwards to attach
    them (plus the user-manager, extension and static-file routes) to
    ``self.app``. The most recently constructed instance is stored as
    ``PromptServer.instance``.
    """

    def __init__(self, loop):
        PromptServer.instance = self

        mimetypes.init()
        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

        self.user_manager = UserManager()
        self.supports = ["custom_nodes_from_web"]
        # NOTE(review): None here, but the queue/history routes dereference it,
        # so it is presumably assigned by the caller before serving requests.
        self.prompt_queue = None
        self.loop = loop
        # Events posted by send_sync() (possibly from other threads), drained by publish_loop().
        self.messages = asyncio.Queue()
        # Next queue number handed out by POST /prompt when the request supplies none.
        self.number = 0

        middlewares = [cache_control]
        if args.enable_cors_header:
            middlewares.append(create_cors_middleware(args.enable_cors_header))

        # args.max_upload_size is scaled by 1024*1024 to get aiohttp's byte limit.
        max_upload_size = round(args.max_upload_size * 1024 * 1024)
        self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
        # Connected WebSocket clients, keyed by client id (sid).
        self.sockets = dict()
        self.web_root = os.path.join(os.path.dirname(
            os.path.realpath(__file__)), "web")
        routes = web.RouteTableDef()
        self.routes = routes
        # Replayed to a reconnecting client whose sid equals client_id (see /ws).
        self.last_node_id = None
        self.client_id = None

        self.on_prompt_handlers = []

        def is_within_directory(directory, path):
            """Return True when *path* lies inside *directory* (commonpath check).

            os.path.commonpath raises ValueError for paths on different drives
            or a mix of absolute and relative paths; those count as outside.
            """
            try:
                return os.path.commonpath((directory, path)) == directory
            except ValueError:
                return False

        @routes.get('/ws')
        async def websocket_handler(request):
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            sid = request.rel_url.query.get('clientId', '')
            if sid:
                # Reusing existing session, remove old
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client
                await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
                # On reconnect if we are the currently executing client send the current node
                if self.client_id == sid and self.last_node_id is not None:
                    await self.send("executing", { "node": self.last_node_id }, sid)

                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        logging.warning('ws connection closed with exception %s', ws.exception())
            finally:
                # A client that reconnected with the same clientId has already
                # replaced this entry; only unregister if it is still *this* socket,
                # otherwise the stale handler would drop the live connection.
                if self.sockets.get(sid) is ws:
                    self.sockets.pop(sid, None)
            return ws

        @routes.get("/")
        async def get_root(request):
            return web.FileResponse(os.path.join(self.web_root, "index.html"))

        @routes.get("/embeddings")
        async def get_embeddings(request):
            # Embedding names without their file extensions.
            embeddings = folder_paths.get_filename_list("embeddings")
            return web.json_response([os.path.splitext(a)[0] for a in embeddings])

        @routes.get("/extensions")
        async def get_extensions(request):
            """List URL paths of all frontend extension scripts (.js files)."""
            files = glob.glob(os.path.join(
                glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)

            extensions = ["/" + os.path.relpath(f, self.web_root).replace("\\", "/") for f in files]

            # Custom-node web dirs are served under /extensions/<quoted name>/ (see add_routes).
            for name, web_dir in nodes.EXTENSION_WEB_DIRS.items():
                files = glob.glob(os.path.join(glob.escape(web_dir), '**/*.js'), recursive=True)
                prefix = "/extensions/" + urllib.parse.quote(name) + "/"
                extensions.extend(prefix + os.path.relpath(f, web_dir).replace("\\", "/") for f in files)

            return web.json_response(extensions)

        def get_dir_by_type(dir_type):
            """Map an upload "type" value to ``(directory, normalized_type)``.

            "temp" and "output" select their folders; None and any other value
            fall back to the input folder (previously an unknown value left the
            directory unbound and crashed the request).
            """
            if dir_type == "temp":
                type_dir = folder_paths.get_temp_directory()
            elif dir_type == "output":
                type_dir = folder_paths.get_output_directory()
            else:
                dir_type = "input"
                type_dir = folder_paths.get_input_directory()
            return type_dir, dir_type

        def image_upload(post, image_save_function=None):
            """Store the uploaded ``image`` form field and describe where it went.

            :param post: parsed multipart form (image, overwrite, type, subfolder).
            :param image_save_function: optional ``f(image, post, filepath)`` that
                writes the file itself; a non-None return value is treated as an
                error response and returned to the client unchanged.
            :return: JSON ``{"name", "subfolder", "type"}`` on success, else 400.
            """
            image = post.get("image")
            overwrite = post.get("overwrite")

            image_upload_type = post.get("type")
            upload_dir, image_upload_type = get_dir_by_type(image_upload_type)

            # A plain (non-file) form field has no .file attribute.
            if not image or not getattr(image, "file", None):
                return web.Response(status=400)

            filename = image.filename
            if not filename:
                return web.Response(status=400)

            subfolder = post.get("subfolder", "")
            full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
            filepath = os.path.abspath(os.path.join(full_output_folder, filename))

            # Reject subfolder/filename combinations that escape the upload directory.
            if not is_within_directory(upload_dir, filepath):
                return web.Response(status=400)

            os.makedirs(full_output_folder, exist_ok=True)

            if overwrite not in ("true", "1"):
                # Pick the first free "name (i).ext" instead of clobbering a file.
                stem, ext = os.path.splitext(filename)
                i = 1
                while os.path.exists(filepath):
                    filename = f"{stem} ({i}){ext}"
                    filepath = os.path.join(full_output_folder, filename)
                    i += 1

            if image_save_function is not None:
                error_response = image_save_function(image, post, filepath)
                if error_response is not None:
                    return error_response
            else:
                with open(filepath, "wb") as f:
                    f.write(image.file.read())

            return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})

        @routes.post("/upload/image")
        async def upload_image(request):
            post = await request.post()
            return image_upload(post)

        @routes.post("/upload/mask")
        async def upload_mask(request):
            post = await request.post()

            def image_save_function(image, post, filepath):
                """Save the image named by ``original_ref`` with the mask's alpha.

                Returns None on success or an error response (400 bad reference,
                403 subfolder escape, 404 original missing) that image_upload relays.
                """
                try:
                    original_ref = json.loads(post.get("original_ref"))
                    ref_filename = original_ref['filename']
                except (TypeError, ValueError, KeyError):
                    # Missing field, invalid JSON, or not an object with "filename".
                    return web.Response(status=400)

                filename, output_dir = folder_paths.annotated_filepath(ref_filename)

                # validation for security: prevent accessing arbitrary path
                if not filename or filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    ref_type = original_ref.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(ref_type)

                if output_dir is None:
                    return web.Response(status=400)

                if original_ref.get("subfolder", "") != "":
                    full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
                    if not is_within_directory(output_dir, os.path.abspath(full_output_dir)):
                        return web.Response(status=403)
                    output_dir = full_output_dir

                file = os.path.join(output_dir, filename)
                if not os.path.isfile(file):
                    return web.Response(status=404)

                with Image.open(file) as original_pil:
                    # Copy the original's text chunks (if any) into the saved PNG.
                    metadata = PngInfo()
                    if hasattr(original_pil, 'text'):
                        for key in original_pil.text:
                            metadata.add_text(key, original_pil.text[key])
                    original_pil = original_pil.convert('RGBA')
                    with Image.open(image.file) as mask_file:
                        mask_pil = mask_file.convert('RGBA')

                    # alpha copy
                    new_alpha = mask_pil.getchannel('A')
                    original_pil.putalpha(new_alpha)
                    original_pil.save(filepath, compress_level=4, pnginfo=metadata)
                return None

            return image_upload(post, image_save_function)

        @routes.get("/view")
        async def view_image(request):
            """Serve a stored image, optionally as a preview or a single channel."""
            query = request.rel_url.query
            if "filename" in query:
                filename = query["filename"]
                filename, output_dir = folder_paths.annotated_filepath(filename)

                # validation for security: prevent accessing arbitrary path
                if not filename or filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    dir_type = query.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(dir_type)

                if output_dir is None:
                    return web.Response(status=400)

                if "subfolder" in query:
                    full_output_dir = os.path.join(output_dir, query["subfolder"])
                    if not is_within_directory(output_dir, os.path.abspath(full_output_dir)):
                        return web.Response(status=403)
                    output_dir = full_output_dir

                filename = os.path.basename(filename)
                file = os.path.join(output_dir, filename)

                if os.path.isfile(file):
                    headers = {"Content-Disposition": f"filename=\"{filename}\""}

                    if 'preview' in query:
                        # preview=<format>[;<quality>]; only webp/jpeg are honored,
                        # and any alpha-bearing channel request forces webp.
                        with Image.open(file) as img:
                            preview_info = query['preview'].split(';')
                            image_format = preview_info[0]
                            if image_format not in ['webp', 'jpeg'] or 'a' in query.get('channel', ''):
                                image_format = 'webp'

                            quality = 90
                            if preview_info[-1].isdigit():
                                quality = int(preview_info[-1])

                            buffer = BytesIO()
                            if image_format in ['jpeg'] or query.get('channel', '') == 'rgb':
                                img = img.convert("RGB")
                            img.save(buffer, format=image_format, quality=quality)

                            return web.Response(body=buffer.getvalue(), content_type=f'image/{image_format}',
                                                headers=headers)

                    channel = query.get('channel', 'rgba')

                    if channel == 'rgb':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                r, g, b, a = img.split()
                                new_img = Image.merge('RGB', (r, g, b))
                            else:
                                new_img = img.convert("RGB")

                            buffer = BytesIO()
                            new_img.save(buffer, format='PNG')

                            return web.Response(body=buffer.getvalue(), content_type='image/png',
                                                headers=headers)

                    elif channel == 'a':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                _, _, _, a = img.split()
                            else:
                                # No alpha channel: treat as fully opaque.
                                a = Image.new('L', img.size, 255)

                            # alpha img
                            alpha_img = Image.new('RGBA', img.size)
                            alpha_img.putalpha(a)
                            alpha_buffer = BytesIO()
                            alpha_img.save(alpha_buffer, format='PNG')

                            return web.Response(body=alpha_buffer.getvalue(), content_type='image/png',
                                                headers=headers)
                    else:
                        return web.FileResponse(file, headers=headers)

            return web.Response(status=404)

        @routes.get("/view_metadata/{folder_name}")
        async def view_metadata(request):
            """Return the ``__metadata__`` object from a .safetensors file header."""
            folder_name = request.match_info.get("folder_name", None)
            if folder_name is None:
                return web.Response(status=404)
            if "filename" not in request.rel_url.query:
                return web.Response(status=404)

            filename = request.rel_url.query["filename"]
            if not filename.endswith(".safetensors"):
                return web.Response(status=404)

            safetensors_path = folder_paths.get_full_path(folder_name, filename)
            if safetensors_path is None:
                return web.Response(status=404)
            # NOTE(review): max_size presumably caps the header bytes read (1 MiB).
            out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
            if out is None:
                return web.Response(status=404)
            dt = json.loads(out)
            if "__metadata__" not in dt:
                return web.Response(status=404)
            return web.json_response(dt["__metadata__"])

        @routes.get("/system_stats")
        async def get_system_stats(request):
            """Report OS/Python details and memory figures for the torch device."""
            device = comfy.model_management.get_torch_device()
            device_name = comfy.model_management.get_torch_device_name(device)
            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
            system_stats = {
                "system": {
                    "os": os.name,
                    "python_version": sys.version,
                    # True when the interpreter lives in a directory named "python_embeded".
                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
                },
                "devices": [
                    {
                        "name": device_name,
                        "type": device.type,
                        "index": device.index,
                        "vram_total": vram_total,
                        "vram_free": vram_free,
                        "torch_vram_total": torch_vram_total,
                        "torch_vram_free": torch_vram_free,
                    }
                ]
            }
            return web.json_response(system_stats)

        @routes.get("/prompt")
        async def get_prompt(request):
            return web.json_response(self.get_queue_info())

        def node_info(node_class):
            """Describe a registered node class (inputs, outputs, names, category)."""
            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
            info = {}
            info['input'] = obj_class.INPUT_TYPES()
            info['output'] = obj_class.RETURN_TYPES
            info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
            info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
            info['name'] = node_class
            info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class)
            info['description'] = getattr(obj_class, 'DESCRIPTION', '')
            # 'sd' is the fallback category for classes without CATEGORY.
            info['category'] = getattr(obj_class, 'CATEGORY', 'sd')
            info['output_node'] = hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True
            return info

        @routes.get("/object_info")
        async def get_object_info(request):
            """Describe every registered node; nodes that fail to describe are skipped."""
            out = {}
            for x in nodes.NODE_CLASS_MAPPINGS:
                try:
                    out[x] = node_info(x)
                except Exception:
                    print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr)
                    traceback.print_exc()
            return web.json_response(out)

        @routes.get("/object_info/{node_class}")
        async def get_object_info_node(request):
            node_class = request.match_info.get("node_class", None)
            out = {}
            if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
                out[node_class] = node_info(node_class)
            return web.json_response(out)

        @routes.get("/history")
        async def get_history(request):
            max_items = request.rel_url.query.get("max_items", None)
            if max_items is not None:
                try:
                    max_items = int(max_items)
                except ValueError:
                    # A non-integer max_items is a client error, not a server crash.
                    return web.Response(status=400)
            return web.json_response(self.prompt_queue.get_history(max_items=max_items))

        @routes.get("/history/{prompt_id}")
        async def get_history_by_id(request):
            prompt_id = request.match_info.get("prompt_id", None)
            return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))

        @routes.get("/queue")
        async def get_queue(request):
            current_queue = self.prompt_queue.get_current_queue()
            queue_info = {
                'queue_running': current_queue[0],
                'queue_pending': current_queue[1],
            }
            return web.json_response(queue_info)

        @routes.post("/prompt")
        async def post_prompt(request):
            """Validate a submitted prompt and enqueue it.

            Queue number: an explicit ``number`` wins; otherwise the running
            counter is used (negated when ``front`` is truthy) and incremented.
            """
            logging.info("got prompt")
            json_data = await request.json()
            json_data = self.trigger_on_prompt(json_data)

            if "number" in json_data:
                number = float(json_data['number'])
            else:
                number = self.number
                if json_data.get('front'):
                    number = -number

                self.number += 1

            if "prompt" not in json_data:
                return web.json_response({"error": "no prompt", "node_errors": []}, status=400)

            prompt = json_data["prompt"]
            # valid is indexed as (ok, error, outputs_to_execute, node_errors).
            valid = execution.validate_prompt(prompt)
            extra_data = json_data.get("extra_data", {})

            if "client_id" in json_data:
                extra_data["client_id"] = json_data["client_id"]
            if valid[0]:
                prompt_id = str(uuid.uuid4())
                outputs_to_execute = valid[2]
                self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
                response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
                return web.json_response(response)
            else:
                logging.warning("invalid prompt: {}".format(valid[1]))
                return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)

        @routes.post("/queue")
        async def post_queue(request):
            """Clear the queue and/or delete queued items by prompt id."""
            json_data = await request.json()
            if json_data.get("clear"):
                self.prompt_queue.wipe_queue()
            if "delete" in json_data:
                for id_to_delete in json_data['delete']:
                    # Queue items are indexed so that item[1] is the prompt id.
                    delete_func = lambda a: a[1] == id_to_delete
                    self.prompt_queue.delete_queue_item(delete_func)

            return web.Response(status=200)

        @routes.post("/interrupt")
        async def post_interrupt(request):
            nodes.interrupt_processing()
            return web.Response(status=200)

        @routes.post("/free")
        async def post_free(request):
            """Set the queue's unload_models / free_memory flags when requested."""
            json_data = await request.json()
            unload_models = json_data.get("unload_models", False)
            free_memory = json_data.get("free_memory", False)
            if unload_models:
                self.prompt_queue.set_flag("unload_models", unload_models)
            if free_memory:
                self.prompt_queue.set_flag("free_memory", free_memory)
            return web.Response(status=200)

        @routes.post("/history")
        async def post_history(request):
            """Clear the history and/or delete history entries by id."""
            json_data = await request.json()
            if json_data.get("clear"):
                self.prompt_queue.wipe_history()
            if "delete" in json_data:
                for id_to_delete in json_data['delete']:
                    self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

    def add_routes(self):
        """Attach user-manager, API, extension and static routes to ``self.app``.

        Registration order is kept as-is; the catch-all static route for
        ``web_root`` is added last.
        """
        self.user_manager.add_routes(self.routes)
        self.app.add_routes(self.routes)

        for name, web_dir in nodes.EXTENSION_WEB_DIRS.items():
            self.app.add_routes([
                web.static('/extensions/' + urllib.parse.quote(name), web_dir),
            ])

        self.app.add_routes([
            web.static('/', self.web_root),
        ])

    def get_queue_info(self):
        """Return ``{"exec_info": {"queue_remaining": <tasks remaining>}}``."""
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    async def send(self, event, data, sid=None):
        """Dispatch an event to one client (``sid``) or all clients (``sid=None``).

        UNENCODED_PREVIEW_IMAGE data goes through send_image; bytes/bytearray
        payloads are sent as binary frames; anything else as JSON.
        """
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            await self.send_image(data, sid=sid)
        elif isinstance(data, (bytes, bytearray)):
            await self.send_bytes(event, data, sid)
        else:
            await self.send_json(event, data, sid)

    def encode_bytes(self, event, data):
        """Prefix *data* with *event* packed as a big-endian unsigned 32-bit int.

        :raises RuntimeError: if *event* is not an int.
        """
        if not isinstance(event, int):
            raise RuntimeError(f"Binary event types must be integers, got {event}")

        packed = struct.pack(">I", event)
        message = bytearray(packed)
        message.extend(data)
        return message

    async def send_image(self, image_data, sid=None):
        """Encode a PIL image and send it as a PREVIEW_IMAGE binary event.

        ``image_data[0]`` is the PIL format name ("PNG" -> tag 2, anything else
        -> tag 1), ``image_data[1]`` the image, ``image_data[2]`` an optional
        bound: when not None the image is resized with ImageOps.contain to fit
        a max_size x max_size box. The payload is the 4-byte big-endian tag
        followed by the encoded image.
        """
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            # Newer Pillow exposes Image.Resampling; older releases only have
            # the module-level ANTIALIAS constant.
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.ANTIALIAS

            image = ImageOps.contain(image, (max_size, max_size), resampling)
        type_num = 2 if image_type == "PNG" else 1

        bytesIO = BytesIO()
        header = struct.pack(">I", type_num)
        bytesIO.write(header)
        image.save(bytesIO, format=image_type, quality=95, compress_level=1)
        preview_bytes = bytesIO.getvalue()
        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

    async def send_bytes(self, event, data, sid=None):
        """Send a binary event to client *sid*, or to every client when None."""
        message = self.encode_bytes(event, data)

        if sid is None:
            # Snapshot: self.sockets may change while we await each send.
            sockets = list(self.sockets.values())
            for ws in sockets:
                await send_socket_catch_exception(ws.send_bytes, message)
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_bytes, message)

    async def send_json(self, event, data, sid=None):
        """Send ``{"type": event, "data": data}`` to client *sid*, or to all clients."""
        message = {"type": event, "data": data}

        if sid is None:
            # Snapshot: self.sockets may change while we await each send.
            sockets = list(self.sockets.values())
            for ws in sockets:
                await send_socket_catch_exception(ws.send_json, message)
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_json, message)

    def send_sync(self, event, data, sid=None):
        """Queue an event for publish_loop via ``self.loop.call_soon_threadsafe``."""
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    def queue_updated(self):
        """Broadcast the current queue status to all clients."""
        self.send_sync("status", { "status": self.get_queue_info() })

    async def publish_loop(self):
        """Forever forward messages queued by send_sync to ``send``."""
        while True:
            msg = await self.messages.get()
            await self.send(*msg)

    async def start(self, address, port, verbose=True, call_on_start=None):
        """Start serving ``self.app`` on address:port (access log disabled).

        :param call_on_start: optional ``f(address, port)`` invoked once the site is up.
        """
        runner = web.AppRunner(self.app, access_log=None)
        await runner.setup()
        site = web.TCPSite(runner, address, port)
        await site.start()

        if verbose:
            logging.info("Starting server\n")
            logging.info("To see the GUI go to: http://{}:{}".format(address, port))
        if call_on_start is not None:
            call_on_start(address, port)

    def add_on_prompt_handler(self, handler):
        """Register ``handler(json_data) -> json_data`` for POST /prompt bodies."""
        self.on_prompt_handlers.append(handler)

    def trigger_on_prompt(self, json_data):
        """Run *json_data* through each on-prompt handler in registration order.

        Each handler's return value replaces json_data. A handler that raises is
        logged (with traceback) and skipped, and json_data is not reassigned.
        """
        for handler in self.on_prompt_handlers:
            try:
                json_data = handler(json_data)
            except Exception:
                logging.warning("[ERROR] An error occurred during the on_prompt_handler processing")
                traceback.print_exc()

        return json_data