Commit ccb6b70d authored by comfyanonymous's avatar comfyanonymous
Browse files

Move image encoding outside of sampling loop for better preview perf.

parent 39c58b22
import torch
from PIL import Image, ImageOps
from io import BytesIO
from PIL import Image
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
......@@ -15,26 +14,7 @@ class LatentPreviewer:
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
preview_type = 1
if preview_format == "JPEG":
preview_type = 1
elif preview_format == "PNG":
preview_type = 2
bytesIO = BytesIO()
header = struct.pack(">I", preview_type)
bytesIO.write(header)
preview_image.save(bytesIO, format=preview_format, quality=95)
preview_bytes = bytesIO.getvalue()
return preview_bytes
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
......
......@@ -92,10 +92,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server):
def hook(value, total, preview_image_bytes):
def hook(value, total, preview_image):
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
if preview_image_bytes is not None:
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
......
......@@ -8,7 +8,7 @@ import uuid
import json
import glob
import struct
from PIL import Image
from PIL import Image, ImageOps
from io import BytesIO
try:
......@@ -29,6 +29,7 @@ import comfy.model_management
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
async def send_socket_catch_exception(function, message):
try:
......@@ -498,7 +499,9 @@ class PromptServer():
return prompt_info
async def send(self, event, data, sid=None):
if isinstance(data, (bytes, bytearray)):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid)
elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid)
else:
await self.send_json(event, data, sid)
......@@ -512,6 +515,30 @@ class PromptServer():
message.extend(data)
return message
async def send_image(self, image_data, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
image = ImageOps.contain(image, (max_size, max_size), resampling)
type_num = 1
if image_type == "JPEG":
type_num = 1
elif image_type == "PNG":
type_num = 2
bytesIO = BytesIO()
header = struct.pack(">I", type_num)
bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment