# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime

import GPUtil
import numpy as np
import torch
from diffusers import FluxFillPipeline
from PIL import Image

from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES

# import gradio last to avoid conflicts with other imports
import gradio as gr

# Placeholder returned when the user submits an empty prompt on a blank canvas.
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))

args = get_args()

if args.precision == "bf16":
    # Full-precision (bfloat16) reference pipeline.
    pipeline = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
else:
    assert args.precision == "int4"
    # SVDQuant INT4 transformer swapped into the standard diffusers pipeline.
    pipeline_init_kwargs = {}
    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        # Optional quantized T5 text encoder; imported lazily since it is only
        # needed on this path.
        from nunchaku.models.text_encoder import NunchakuT5EncoderModel

        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
    pipeline = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "int4"

safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)

# "Nearly blank" threshold for a 1024x1024 RGB canvas: 1024 * 1024 * 3 channel
# values, minus a small tolerance of 100 — i.e. the canvas counts as blank if
# all but at most 100 channel values are pure white (255) or pure black (0).
_NEARLY_BLANK_THRESHOLD = 1024 * 1024 * 3 - 100  # == 3145628, value unchanged


def save_image(img):
    """Save *img* (a PIL image, or a Gradio editor dict holding one under
    "composite") to a fresh temporary PNG and return the file path.

    NOTE(review): delete=False means the temp file is intentionally left on
    disk so Gradio can serve it afterwards.
    """
    if isinstance(img, dict):
        img = img["composite"]
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    return temp_file.name


def run(
    image, prompt: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image.Image, str]:
    """Run one inpainting pass and return ``(result_image, latency_string)``.

    Args:
        image: Gradio ImageEditor payload — a dict with "composite" (flattened
            view), "layers" (drawn strokes; alpha channel of layer 0 is the
            inpainting mask) and "background" (the original photo).
        prompt: user prompt; replaced with a default if flagged unsafe.
        prompt_template: format string with a ``{prompt}`` placeholder.
        num_inference_steps: diffusion step count.
        guidance_scale: classifier-free guidance strength.
        seed: RNG seed for reproducible generation.
    """
    print(f"Prompt: {prompt}")
    image_numpy = np.array(image["composite"].convert("RGB"))
    # Nothing to do: empty prompt on an (almost) all-white or all-black canvas.
    if prompt.strip() == "" and (
        np.sum(image_numpy == 255) >= _NEARLY_BLANK_THRESHOLD or np.sum(image_numpy == 0) >= _NEARLY_BLANK_THRESHOLD
    ):
        return blank_image, "Please input the prompt or draw something."
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)
    mask = image["layers"][0].getchannel(3)  # Mask is stored in the last channel
    pic = image["background"].convert("RGB")  # This is the original photo
    start_time = time.time()
    result_image = pipeline(
        prompt=prompt,
        image=pic,
        mask_image=mask,
        guidance_scale=guidance_scale,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    latency = time.time() - start_time
    # Display sub-second latencies in milliseconds for readability.
    if latency < 1:
        latency = latency * 1000
        latency_str = f"{latency:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    if is_unsafe_prompt:
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
    if args.count_use:
        # Best-effort usage counter persisted to local text files.
        # NOTE(review): read-modify-write is not safe under concurrent
        # requests — acceptable for a single-worker demo.
        if os.path.exists("use_count.txt"):
            with open("use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
        with open("use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")
    return result_image, latency_str


with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
    if len(gpus) > 0:
        gpu = gpus[0]
        memory = gpu.memoryTotal / 1024
        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'Notice: We will replace unsafe prompts with a default prompt: "A peaceful world."' def get_header_str(): if args.count_use: if os.path.exists("use_count.txt"): with open("use_count.txt", "r") as f: count = int(f.read()) else: count = 0 count_info = ( f"