server.py 25.9 KB
Newer Older
pythongosssss's avatar
pythongosssss committed
1
2
3
import os
import sys
import asyncio
4
5
import traceback

pythongosssss's avatar
pythongosssss committed
6
import nodes
7
import folder_paths
8
import execution
pythongosssss's avatar
pythongosssss committed
9
import uuid
10
import urllib
pythongosssss's avatar
pythongosssss committed
11
import json
12
import glob
space-nuko's avatar
space-nuko committed
13
import struct
14
from PIL import Image, ImageOps
Chris's avatar
Chris committed
15
from PIL.PngImagePlugin import PngInfo
16
17
from io import BytesIO

comfyanonymous's avatar
comfyanonymous committed
18
19
import aiohttp
from aiohttp import web
pythongosssss's avatar
pythongosssss committed
20

comfyanonymous's avatar
Style.  
comfyanonymous committed
21
import mimetypes
EllangoK's avatar
EllangoK committed
22
from comfy.cli_args import args
23
import comfy.utils
space-nuko's avatar
space-nuko committed
24
import comfy.model_management
25

26
from app.user_manager import UserManager
space-nuko's avatar
space-nuko committed
27
28
29

class BinaryEventTypes:
    """Integer identifiers for binary events pushed to websocket clients."""

    # Event id used when an encoded preview image is sent as bytes.
    PREVIEW_IMAGE = 1
    # Event id for a preview that still has to be encoded; PromptServer.send
    # routes payloads with this id to send_image.
    UNENCODED_PREVIEW_IMAGE = 2
space-nuko's avatar
space-nuko committed
31

32
33
34
35
36
async def send_socket_catch_exception(function, message):
    """Await ``function(message)`` and log, rather than raise, send failures.

    Args:
        function: Callable returning an awaitable (e.g. a socket's send method).
        message: Payload handed to ``function`` unchanged.

    Only aiohttp client errors and connection resets are swallowed; any other
    exception propagates to the caller.
    """
    try:
        pending = function(message)
        await pending
    except (ConnectionResetError, aiohttp.ClientPayloadError, aiohttp.ClientError) as exc:
        print("send error:", exc)
space-nuko's avatar
space-nuko committed
37

38
39
40
41
42
43
44
@web.middleware
async def cache_control(request: web.Request, handler):
    """Middleware that defaults ``Cache-Control: no-cache`` on .js/.css responses.

    The header is only added when the handler did not already set one.
    """
    response: web.Response = await handler(request)
    if request.path.endswith(('.js', '.css')):
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response

EllangoK's avatar
EllangoK committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def create_cors_middleware(allowed_origin: str):
    """Build an aiohttp middleware that stamps CORS headers on every response.

    Args:
        allowed_origin: Value used for ``Access-Control-Allow-Origin``.

    Returns:
        The middleware coroutine function.
    """
    cors_headers = {
        'Access-Control-Allow-Origin': allowed_origin,
        'Access-Control-Allow-Methods': 'POST, GET, DELETE, PUT, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
        'Access-Control-Allow-Credentials': 'true',
    }

    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        # Pre-flight requests are answered directly with an empty success
        # response; everything else goes through the real handler.
        is_preflight = request.method == "OPTIONS"
        response = web.Response() if is_preflight else await handler(request)

        for header_name, header_value in cors_headers.items():
            response.headers[header_name] = header_value
        return response

    return cors_middleware
EllangoK's avatar
EllangoK committed
61

pythongosssss's avatar
pythongosssss committed
62
63
class PromptServer():
    def __init__(self, loop):
        """Create the aiohttp application, shared state and route table.

        Args:
            loop: Event loop stored on the instance; ``send_sync`` uses it to
                schedule queue insertions in a thread-safe way.
        """
        # Expose the server through a class attribute.
        PromptServer.instance = self

        mimetypes.init()
        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

        self.user_manager = UserManager()
        self.supports = ["custom_nodes_from_web"]
        # Not created here; handlers and get_queue_info read it later.
        self.prompt_queue = None
        self.loop = loop
        # Outgoing (event, data, sid) tuples, drained by publish_loop.
        self.messages = asyncio.Queue()
        # Counter used as the default queue number for POST /prompt.
        self.number = 0

        middlewares = [cache_control]
        if args.enable_cors_header:
            middlewares.append(create_cors_middleware(args.enable_cors_header))

        # presumably args.max_upload_size is in megabytes - TODO confirm
        max_upload_size = round(args.max_upload_size * 1024 * 1024)
        self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
        # Open websocket connections keyed by client/session id.
        self.sockets = dict()
        self.web_root = os.path.join(os.path.dirname(
            os.path.realpath(__file__)), "web")
        routes = web.RouteTableDef()
        self.routes = routes
        # Compared against the reconnecting client in the /ws handler.
        self.last_node_id = None
        self.client_id = None

        self.on_prompt_handlers = []

pythongosssss's avatar
pythongosssss committed
92
93
94
95
        @routes.get('/ws')
        async def websocket_handler(request):
            """Accept a websocket connection and keep it registered while open.

            The session id comes from the ``clientId`` query parameter; when it
            is absent a fresh id is generated. The socket is stored in
            ``self.sockets`` under that id for the lifetime of the connection.
            """
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            sid = request.rel_url.query.get('clientId', '')
            if sid:
                # Reusing existing session, remove old
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client
                await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
                # On reconnect if we are the currently executing client send the current node
                if self.client_id == sid and self.last_node_id is not None:
                    await self.send("executing", { "node": self.last_node_id }, sid)

                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        print('ws connection closed with exception %s' % ws.exception())
            finally:
                # Only unregister our own socket. After a reconnect with the
                # same clientId the entry belongs to the newer connection and
                # must survive this (stale) handler finishing.
                if self.sockets.get(sid) is ws:
                    del self.sockets[sid]
            return ws

        @routes.get("/")
        async def get_root(request):
            """Serve ``index.html`` from the web root."""
            index_page = os.path.join(self.web_root, "index.html")
            return web.FileResponse(index_page)
122

123
124
        @routes.get("/embeddings")
        async def get_embeddings(request):
            """Return the embedding file names with their extensions stripped."""
            # The parameter used to be called `self`, which shadowed the server
            # instance, and the handler was a plain function rather than a
            # coroutine like every other route.
            embeddings = folder_paths.get_filename_list("embeddings")
            return web.json_response([os.path.splitext(name)[0] for name in embeddings])
127

128
129
        @routes.get("/extensions")
        async def get_extensions(request):
            """List the URLs of all extension ``.js`` files.

            Covers files under ``<web_root>/extensions`` and under every
            directory registered in ``nodes.EXTENSION_WEB_DIRS``.
            """
            def relative_url(path, root):
                # URLs always use forward slashes, also on Windows.
                return os.path.relpath(path, root).replace("\\", "/")

            core_files = glob.glob(os.path.join(
                glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
            extensions = ["/" + relative_url(f, self.web_root) for f in core_files]

            for name, web_dir in nodes.EXTENSION_WEB_DIRS.items():
                files = glob.glob(os.path.join(glob.escape(web_dir), '**/*.js'), recursive=True)
                prefix = "/extensions/" + urllib.parse.quote(name) + "/"
                extensions.extend(prefix + relative_url(f, web_dir) for f in files)

            return web.json_response(extensions)
141

142
143
        def get_dir_by_type(dir_type):
            """Map a directory type name to its directory.

            Args:
                dir_type: "input", "temp", "output" or None (treated as "input").

            Returns:
                ``(directory, dir_type)``. ``directory`` is None when
                ``dir_type`` is not one of the known names.
            """
            if dir_type is None:
                dir_type = "input"

            if dir_type == "input":
                type_dir = folder_paths.get_input_directory()
            elif dir_type == "temp":
                type_dir = folder_paths.get_temp_directory()
            elif dir_type == "output":
                type_dir = folder_paths.get_output_directory()
            else:
                # Previously `type_dir` stayed unbound here and the return
                # statement raised UnboundLocalError.
                type_dir = None

            return type_dir, dir_type
154

comfyanonymous's avatar
comfyanonymous committed
155
        def image_upload(post, image_save_function=None):
            """Store an uploaded image and describe where it was saved.

            Args:
                post: Parsed form data; reads the "image", "overwrite", "type"
                    and "subfolder" fields.
                image_save_function: Optional callable
                    ``(image, post, filepath)`` that writes the file itself.
                    If it returns a non-None value, that value is returned to
                    the client instead of the success payload.

            Returns:
                A JSON response with "name", "subfolder" and "type", or a 400
                response when the upload is missing or targets a path outside
                the upload directory.
            """
            image = post.get("image")
            overwrite = post.get("overwrite")

            image_upload_type = post.get("type")
            upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
            if upload_dir is None:
                return web.Response(status=400)

            # A non-file form field has no `.file` attribute.
            if not (image and getattr(image, "file", None)):
                return web.Response(status=400)

            filename = image.filename
            if not filename:
                return web.Response(status=400)

            subfolder = post.get("subfolder", "")
            full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
            filepath = os.path.abspath(os.path.join(full_output_folder, filename))

            # Compare normalized absolute paths; commonpath raises ValueError
            # for paths that cannot share a root (e.g. different drives).
            upload_root = os.path.abspath(upload_dir)
            try:
                inside_root = os.path.commonpath((upload_root, filepath)) == upload_root
            except ValueError:
                inside_root = False
            if not inside_root:
                return web.Response(status=400)

            os.makedirs(full_output_folder, exist_ok=True)

            if overwrite not in ("true", "1"):
                # Avoid clobbering: pick "name (1).ext", "name (2).ext", ...
                stem, ext = os.path.splitext(filename)
                i = 1
                while os.path.exists(filepath):
                    filename = f"{stem} ({i}){ext}"
                    filepath = os.path.join(full_output_folder, filename)
                    i += 1

            if image_save_function is not None:
                # The save hook signals validation failures by returning a
                # response; propagate it instead of reporting success.
                failure = image_save_function(image, post, filepath)
                if failure is not None:
                    return failure
            else:
                with open(filepath, "wb") as f:
                    f.write(image.file.read())

            return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})

comfyanonymous's avatar
comfyanonymous committed
198
199
200
201
202
        @routes.post("/upload/image")
        async def upload_image(request):
            """Handle an image upload by delegating the form data to ``image_upload``."""
            form = await request.post()
            return image_upload(form)

203

204
205
206
207
        @routes.post("/upload/mask")
        async def upload_mask(request):
            """Upload a mask and save it as the alpha channel of an existing image.

            The form's "original_ref" field is JSON naming the source image
            ("filename", optional "type" and "subfolder"). The combined image
            is written to the path chosen by ``image_upload``.
            """
            post = await request.post()

            def image_save_function(image, post, filepath):
                """Write the original image with the mask's alpha to ``filepath``.

                Returns a ``web.Response`` on validation failure, else None.
                Nothing is written when the referenced file does not exist.
                """
                original_ref = json.loads(post.get("original_ref"))
                filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])

                # validation for security: prevent accessing arbitrary path
                # (an empty name used to raise IndexError on filename[0])
                if not filename or filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    dir_type = original_ref.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(dir_type)

                if output_dir is None:
                    return web.Response(status=400)

                if original_ref.get("subfolder", "") != "":
                    full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        return web.Response(status=403)
                    output_dir = full_output_dir

                file = os.path.join(output_dir, filename)

                if os.path.isfile(file):
                    with Image.open(file) as original_pil:
                        # Carry the original PNG text chunks over to the result.
                        metadata = PngInfo()
                        if hasattr(original_pil, 'text'):
                            for key in original_pil.text:
                                metadata.add_text(key, original_pil.text[key])
                        combined = original_pil.convert('RGBA')

                    # Close the uploaded mask once its pixels are converted.
                    with Image.open(image.file) as mask_source:
                        mask_pil = mask_source.convert('RGBA')

                    # alpha copy
                    combined.putalpha(mask_pil.getchannel('A'))
                    combined.save(filepath, compress_level=4, pnginfo=metadata)

            return image_upload(post, image_save_function)
pythongosssss's avatar
pythongosssss committed
246

247
        @routes.get("/view")
        async def view_image(request):
            """Serve an image file, optionally as a preview or a single channel set.

            Query parameters read: "filename" (required), "type", "subfolder",
            "preview" ("<format>[;...;<quality>]") and "channel"
            ("rgb", "a", anything else serves the file unchanged).

            Returns 404 when the file or "filename" is missing, 400 for an
            invalid name or unknown directory, 403 for a subfolder that leaves
            the base directory.
            """
            query = request.rel_url.query
            if "filename" not in query:
                return web.Response(status=404)

            filename, output_dir = folder_paths.annotated_filepath(query["filename"])

            # validation for security: prevent accessing arbitrary path
            # (an empty name used to raise IndexError on filename[0])
            if not filename or filename[0] == '/' or '..' in filename:
                return web.Response(status=400)

            if output_dir is None:
                dir_type = query.get("type", "output")
                output_dir = folder_paths.get_directory_by_type(dir_type)

            if output_dir is None:
                return web.Response(status=400)

            if "subfolder" in query:
                full_output_dir = os.path.join(output_dir, query["subfolder"])
                if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                    return web.Response(status=403)
                output_dir = full_output_dir

            filename = os.path.basename(filename)
            file = os.path.join(output_dir, filename)

            if not os.path.isfile(file):
                return web.Response(status=404)

            headers = {"Content-Disposition": f"filename=\"{filename}\""}

            if 'preview' in query:
                # The preview path treats a missing channel as "" (not "rgba").
                preview_channel = query.get('channel', '')
                with Image.open(file) as img:
                    preview_info = query['preview'].split(';')
                    image_format = preview_info[0]
                    # Unknown formats, and any request involving alpha, use webp.
                    if image_format not in ['webp', 'jpeg'] or 'a' in preview_channel:
                        image_format = 'webp'

                    quality = 90
                    if preview_info[-1].isdigit():
                        quality = int(preview_info[-1])

                    buffer = BytesIO()
                    if image_format in ['jpeg'] or preview_channel == 'rgb':
                        img = img.convert("RGB")
                    img.save(buffer, format=image_format, quality=quality)

                    return web.Response(body=buffer.getvalue(), content_type=f'image/{image_format}',
                                        headers=headers)

            channel = query.get("channel", "rgba")

            if channel == 'rgb':
                with Image.open(file) as img:
                    if img.mode == "RGBA":
                        r, g, b, a = img.split()
                        new_img = Image.merge('RGB', (r, g, b))
                    else:
                        new_img = img.convert("RGB")

                    buffer = BytesIO()
                    new_img.save(buffer, format='PNG')

                    return web.Response(body=buffer.getvalue(), content_type='image/png',
                                        headers=headers)

            if channel == 'a':
                with Image.open(file) as img:
                    if img.mode == "RGBA":
                        _, _, _, a = img.split()
                    else:
                        # No alpha channel: treat the image as fully opaque.
                        a = Image.new('L', img.size, 255)

                    # alpha img
                    alpha_img = Image.new('RGBA', img.size)
                    alpha_img.putalpha(a)
                    alpha_buffer = BytesIO()
                    alpha_img.save(alpha_buffer, format='PNG')

                    return web.Response(body=alpha_buffer.getvalue(), content_type='image/png',
                                        headers=headers)

            return web.FileResponse(file, headers=headers)
334

335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
        @routes.get("/view_metadata/{folder_name}")
        async def view_metadata(request):
            """Return the ``__metadata__`` entry of a ``.safetensors`` header.

            Every failure (missing parameters, wrong extension, unknown file,
            unreadable header, no metadata) is reported as 404.
            """
            def not_found():
                return web.Response(status=404)

            query = request.rel_url.query
            folder_name = request.match_info.get("folder_name", None)
            if folder_name is None or "filename" not in query:
                return not_found()

            filename = query["filename"]
            if not filename.endswith(".safetensors"):
                return not_found()

            safetensors_path = folder_paths.get_full_path(folder_name, filename)
            if safetensors_path is None:
                return not_found()

            header = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
            if header is None:
                return not_found()

            parsed = json.loads(header)
            if "__metadata__" in parsed:
                return web.json_response(parsed["__metadata__"])
            return not_found()

space-nuko's avatar
space-nuko committed
358
359
        @routes.get("/system_stats")
        async def get_queue(request):
            """Report OS/Python information and memory figures for the torch device."""
            device = comfy.model_management.get_torch_device()
            device_name = comfy.model_management.get_torch_device_name(device)
            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
            system_stats = {
                "system": {
                    "os": os.name,
                    "python_version": sys.version,
                    # True when the interpreter lives in a directory named
                    # "python_embeded" (sic).
                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
                },
                "devices": [
                    {
                        "name": device_name,
                        "type": device.type,
                        "index": device.index,
                        "vram_total": vram_total,
                        "vram_free": vram_free,
                        "torch_vram_total": torch_vram_total,
                        "torch_vram_free": torch_vram_free,
                    }
                ]
            }
            return web.json_response(system_stats)

pythongosssss's avatar
pythongosssss committed
384
385
386
        @routes.get("/prompt")
        async def get_prompt(request):
            """Return the current queue summary from ``get_queue_info`` as JSON."""
            queue_summary = self.get_queue_info()
            return web.json_response(queue_summary)
387

388
389
390
391
392
393
394
395
396
        def node_info(node_class):
            """Build the JSON-serializable description of a registered node class.

            Args:
                node_class: Key into ``nodes.NODE_CLASS_MAPPINGS``.

            Raises:
                KeyError: If ``node_class`` is not registered.
            """
            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
            info = {}
            info['input'] = obj_class.INPUT_TYPES()
            info['output'] = obj_class.RETURN_TYPES
            # Optional attributes fall back to defaults derived from the outputs.
            if hasattr(obj_class, 'OUTPUT_IS_LIST'):
                info['output_is_list'] = obj_class.OUTPUT_IS_LIST
            else:
                info['output_is_list'] = [False] * len(obj_class.RETURN_TYPES)
            info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
            info['name'] = node_class
            info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class)
            info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
            # 'sd' is the category used when the class does not declare one.
            info['category'] = obj_class.CATEGORY if hasattr(obj_class, 'CATEGORY') else 'sd'
            info['output_node'] = True if (hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True) else False
            return info

pythongosssss's avatar
pythongosssss committed
408
409
410
411
        @routes.get("/object_info")
        async def get_object_info(request):
            """Return ``node_info`` for every registered node class.

            A node whose description cannot be built is logged and left out of
            the result, so one broken node does not fail the whole listing.
            """
            out = {}
            for x in nodes.NODE_CLASS_MAPPINGS:
                try:
                    out[x] = node_info(x)
                except Exception:
                    print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr)
                    traceback.print_exc()
            return web.json_response(out)

        @routes.get("/object_info/{node_class}")
        async def get_object_info_node(request):
            """Return ``node_info`` for one node class, or ``{}`` if it is unknown."""
            node_class = request.match_info.get("node_class", None)
            out = {}
            if node_class is not None and node_class in nodes.NODE_CLASS_MAPPINGS:
                out[node_class] = node_info(node_class)
            return web.json_response(out)
426

pythongosssss's avatar
pythongosssss committed
427
428
        @routes.get("/history")
        async def get_history(request):
            """Return the prompt history, optionally limited by ``max_items``.

            Responds with 400 when ``max_items`` is present but not an integer.
            """
            max_items = request.rel_url.query.get("max_items", None)
            if max_items is not None:
                try:
                    max_items = int(max_items)
                except ValueError:
                    # A malformed query value is a client error, not a crash.
                    return web.Response(status=400)
            return web.json_response(self.prompt_queue.get_history(max_items=max_items))
433

434
435
436
437
438
        @routes.get("/history/{prompt_id}")
        async def get_history(request):
            """Return the history lookup for the ``prompt_id`` path segment as JSON."""
            prompt_id = request.match_info.get("prompt_id", None)
            history = self.prompt_queue.get_history(prompt_id=prompt_id)
            return web.json_response(history)

pythongosssss's avatar
pythongosssss committed
439
440
441
442
443
444
445
        @routes.get("/queue")
        async def get_queue(request):
            """Return the running and pending queue entries as JSON."""
            current_queue = self.prompt_queue.get_current_queue()
            return web.json_response({
                'queue_running': current_queue[0],
                'queue_pending': current_queue[1],
            })
446

pythongosssss's avatar
pythongosssss committed
447
448
449
450
451
452
        @routes.post("/prompt")
        async def post_prompt(request):
            """Validate a submitted prompt and put it on the prompt queue.

            Reads "prompt" (required), "number", "front", "extra_data" and
            "client_id" from the JSON body after the on-prompt handlers have
            run. Responds with the new "prompt_id", its queue "number" and
            "node_errors", or with status 400 and an "error" entry.
            """
            print("got prompt")
            json_data = await request.json()
            json_data = self.trigger_on_prompt(json_data)

            if "number" in json_data:
                try:
                    number = float(json_data['number'])
                except (TypeError, ValueError):
                    # A non-numeric "number" is a client error, not a crash.
                    return web.json_response({"error": "invalid number", "node_errors": []}, status=400)
            else:
                number = self.number
                # "front" negates the running counter to build the queue number.
                if "front" in json_data and json_data['front']:
                    number = -number

                self.number += 1

            if "prompt" not in json_data:
                return web.json_response({"error": "no prompt", "node_errors": []}, status=400)

            prompt = json_data["prompt"]
            # valid is indexed as: [0] ok flag, [1] error, [2] outputs to
            # execute, [3] node errors.
            valid = execution.validate_prompt(prompt)
            extra_data = {}
            if "extra_data" in json_data:
                extra_data = json_data["extra_data"]

            if "client_id" in json_data:
                extra_data["client_id"] = json_data["client_id"]

            if not valid[0]:
                print("invalid prompt:", valid[1])
                return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)

            prompt_id = str(uuid.uuid4())
            outputs_to_execute = valid[2]
            self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
            response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
            return web.json_response(response)
pythongosssss's avatar
pythongosssss committed
485
486
487
488
489
490
491
492
493
494

        @routes.post("/queue")
        async def post_queue(request):
            """Clear the pending queue and/or delete queued items by prompt id.

            JSON body: a truthy "clear" wipes the queue; "delete" is an
            iterable of ids matched against element 1 of each queue item.
            """
            json_data = await request.json()
            if "clear" in json_data and json_data["clear"]:
                self.prompt_queue.wipe_queue()
            if "delete" in json_data:
                for id_to_delete in json_data['delete']:
                    # Bind the id explicitly so the predicate does not depend
                    # on the loop variable's later values.
                    def delete_func(item, target=id_to_delete):
                        return item[1] == target
                    self.prompt_queue.delete_queue_item(delete_func)

            return web.Response(status=200)
pythongosssss's avatar
pythongosssss committed
499
500
501
502
503
504

        @routes.post("/interrupt")
        async def post_interrupt(request):
            """Call ``nodes.interrupt_processing()`` and acknowledge with 200."""
            nodes.interrupt_processing()
            acknowledgement = web.Response(status=200)
            return acknowledgement

505
        @routes.post("/free")
        async def post_free(request):
            """Set the "unload_models" / "free_memory" flags on the prompt queue.

            Each flag is only set when the corresponding JSON field is truthy.
            """
            json_data = await request.json()
            for flag_name in ("unload_models", "free_memory"):
                flag_value = json_data.get(flag_name, False)
                if flag_value:
                    self.prompt_queue.set_flag(flag_name, flag_value)
            return web.Response(status=200)

pythongosssss's avatar
pythongosssss committed
516
517
518
519
520
        @routes.post("/history")
        async def post_history(request):
            """Wipe the history and/or delete individual history entries.

            JSON body: a truthy "clear" wipes everything; "delete" is an
            iterable of ids passed one by one to ``delete_history_item``.
            """
            json_data = await request.json()
            if "clear" in json_data and json_data["clear"]:
                self.prompt_queue.wipe_history()
            if "delete" in json_data:
                for id_to_delete in json_data['delete']:
                    self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)
528
529
        
    def add_routes(self):
        """Register user-manager routes, API routes and static file routes.

        Static routes are added after the route table; the catch-all static
        route for ``/`` is added last.
        """
        self.user_manager.add_routes(self.routes)
        self.app.add_routes(self.routes)

        # One static mount per extension directory, under its URL-quoted name.
        for name, web_dir in nodes.EXTENSION_WEB_DIRS.items():
            self.app.add_routes([
                web.static('/extensions/' + urllib.parse.quote(name), web_dir),
            ])

        self.app.add_routes([
            web.static('/', self.web_root),
        ])

    def get_queue_info(self):
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    async def send(self, event, data, sid=None):
        """Dispatch an event to the matching send routine.

        Args:
            event: Event name, or a ``BinaryEventTypes`` value.
            data: Payload; bytes-like payloads are sent as binary frames,
                everything else as JSON.
            sid: Target session id, or None to reach every connected socket.
        """
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            # Raw preview images are encoded by send_image first.
            await self.send_image(data, sid=sid)
        elif isinstance(data, (bytes, bytearray)):
            await self.send_bytes(event, data, sid)
        else:
            await self.send_json(event, data, sid)

    def encode_bytes(self, event, data):
        """Prefix ``data`` with ``event`` packed as a big-endian unsigned 32-bit int.

        Raises:
            RuntimeError: If ``event`` is not an integer.
        """
        if isinstance(event, int):
            framed = bytearray(struct.pack(">I", event))
            framed.extend(data)
            return framed

        raise RuntimeError(f"Binary event types must be integers, got {event}")

566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
    async def send_image(self, image_data, sid=None):
        """Encode a preview image and send it as a ``PREVIEW_IMAGE`` binary event.

        Args:
            image_data: Sequence of ``(image_type, image, max_size)``;
                ``image_type`` is passed to ``image.save`` as the format and
                ``max_size`` (or None) bounds both dimensions.
            sid: Target session id, or None to reach every connected socket.
        """
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            # Use Image.Resampling when this Pillow build provides it,
            # otherwise the older module-level constant.
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.ANTIALIAS

            image = ImageOps.contain(image, (max_size, max_size), resampling)

        # Payload header: 2 for PNG, 1 for JPEG and any other format.
        type_num = 2 if image_type == "PNG" else 1

        bytesIO = BytesIO()
        header = struct.pack(">I", type_num)
        bytesIO.write(header)
        image.save(bytesIO, format=image_type, quality=95, compress_level=1)
        preview_bytes = bytesIO.getvalue()
        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

space-nuko's avatar
space-nuko committed
590
591
592
593
    async def send_bytes(self, event, data, sid=None):
        message = self.encode_bytes(event, data)

        if sid is None:
594
595
            sockets = list(self.sockets.values())
            for ws in sockets:
596
                await send_socket_catch_exception(ws.send_bytes, message)
space-nuko's avatar
space-nuko committed
597
        elif sid in self.sockets:
598
            await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
space-nuko's avatar
space-nuko committed
599
600

    async def send_json(self, event, data, sid=None):
pythongosssss's avatar
pythongosssss committed
601
602
603
        message = {"type": event, "data": data}

        if sid is None:
604
605
            sockets = list(self.sockets.values())
            for ws in sockets:
606
                await send_socket_catch_exception(ws.send_json, message)
pythongosssss's avatar
pythongosssss committed
607
        elif sid in self.sockets:
608
            await send_socket_catch_exception(self.sockets[sid].send_json, message)
pythongosssss's avatar
pythongosssss committed
609
610
611
612

    def send_sync(self, event, data, sid=None):
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))
613

pythongosssss's avatar
pythongosssss committed
614
615
616
617
618
619
620
621
    def queue_updated(self):
        self.send_sync("status", { "status": self.get_queue_info() })

    async def publish_loop(self):
        while True:
            msg = await self.messages.get()
            await self.send(*msg)

622
    async def start(self, address, port, verbose=True, call_on_start=None):
        """Start serving the application on ``address:port``.

        Args:
            address: Host/interface passed to the TCP site.
            port: TCP port passed to the TCP site.
            verbose: When true, print the startup banner with the GUI URL.
            call_on_start: Optional callable invoked as
                ``call_on_start(address, port)`` once the site is running.
        """
        runner = web.AppRunner(self.app, access_log=None)
        await runner.setup()
        site = web.TCPSite(runner, address, port)
        await site.start()

        if verbose:
            print("Starting server\n")
            print("To see the GUI go to: http://{}:{}".format(address, port))
        if call_on_start is not None:
            call_on_start(address, port)

634
635
636
637
638
639
640
641
642
643
644
645
    def add_on_prompt_handler(self, handler):
        self.on_prompt_handlers.append(handler)

    def trigger_on_prompt(self, json_data):
        for handler in self.on_prompt_handlers:
            try:
                json_data = handler(json_data)
            except Exception as e:
                print(f"[ERROR] An error occurred during the on_prompt_handler processing")
                traceback.print_exc()

        return json_data