# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import tempfile
import time
from datetime import datetime

import GPUtil
import numpy as np
import torch
from PIL import Image

from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES

# import gradio last to avoid conflicts with other imports
import gradio as gr

# Fallback canvas returned when the user submits an empty sketch with no prompt.
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))

args = get_args()

if args.precision == "bf16":
    # Plain bf16 path: base FLUX.1-schnell weights plus the sketch control module.
    pipeline = FluxPix2pixTurboPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
    pipeline.load_control_module(
        "mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo", "sketch.safetensors", alpha=DEFAULT_SKETCH_GUIDANCE
    )
else:
    assert args.precision == "int4"
    pipeline_init_kwargs = {}
    # SVDQuant path: swap in the int4-quantized transformer before building the pipeline.
    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        # Optionally also use the quantized T5 text encoder.
        from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
    pipeline = FluxPix2pixTurboPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "int4"
    pipeline.load_control_module(
        "mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo",
        "sketch.safetensors",
        svdq_lora_path="mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo/svdq-int4-sketch.safetensors",
        alpha=DEFAULT_SKETCH_GUIDANCE,
    )

safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)

# A 1024x1024 RGB canvas holds 1024 * 1024 * 3 = 3,145,728 channel values; the sketch
# is treated as blank when (nearly) all of them are pure white or pure black.
_NEARLY_BLANK_THRESHOLD = 3145628


def save_image(img) -> str:
    """Save a PIL image (or a gradio editor payload holding one) to a temp PNG; return its path."""
    if isinstance(img, dict):
        # gradio's ImageEditor passes a dict; the merged drawing is under "composite".
        img = img["composite"]
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    return temp_file.name


def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image.Image, str]:
    """Generate an image from the user's sketch and prompt.

    Args:
        image: gradio editor payload; the drawing is read from ``image["composite"]``.
        prompt: user prompt; replaced with a default if the safety checker flags it.
        prompt_template: template containing a ``{prompt}`` placeholder.
        sketch_guidance: control-module strength (``alpha``).
        seed: RNG seed for reproducible sampling.

    Returns:
        Tuple of (generated image, human-readable latency string for the UI).
    """
    print(f"Prompt: {prompt}")
    image_numpy = np.array(image["composite"].convert("RGB"))
    # Short-circuit when there is no prompt and the canvas is (nearly) all white or all black.
    is_blank = (
        np.sum(image_numpy == 255) >= _NEARLY_BLANK_THRESHOLD
        or np.sum(image_numpy == 0) >= _NEARLY_BLANK_THRESHOLD
    )
    if prompt.strip() == "" and is_blank:
        return blank_image, "Please input the prompt or draw something."

    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)

    start_time = time.time()
    result_image = pipeline(
        image=image["composite"],
        image_type="sketch",
        alpha=sketch_guidance,
        prompt=prompt,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    latency = time.time() - start_time

    # Sub-second runs are reported in milliseconds for readability.
    if latency < 1:
        latency_str = f"{latency * 1000:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    if is_unsafe_prompt:
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()

    if args.count_use:
        # Best-effort on-disk usage counter (read-modify-write; not safe under
        # concurrent writers — acceptable for a single-process demo).
        if os.path.exists("use_count.txt"):
            with open("use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
        with open("use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")
    return result_image, latency_str


with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
    if len(gpus) > 0:
        gpu = gpus[0]
        memory = gpu.memoryTotal / 1024
        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'Notice: We will replace unsafe prompts with a default prompt: "A peaceful world."' def get_header_str(): if args.count_use: if os.path.exists("use_count.txt"): with open("use_count.txt", "r") as f: count = int(f.read()) else: count = 0 count_info = ( f"