# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py import os import random import time from datetime import datetime import GPUtil import torch from diffusers import FluxPipeline, FluxPriorReduxPipeline from PIL import Image from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel from utils import get_args from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED # import gradio last to avoid conflicts with other imports import gradio as gr args = get_args() pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained( "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16 ).to("cuda") if args.precision == "bf16": pipeline = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 ) pipeline = pipeline.to("cuda") pipeline.precision = "bf16" else: assert args.precision == "int4" pipeline_init_kwargs = {} transformer = NunchakuFluxTransformer2dModel.from_pretrained( "mit-han-lab/svdq-int4-flux.1-dev" ) pipeline = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", text_encoder=None, text_encoder_2=None, transformer=transformer, torch_dtype=torch.bfloat16, ) pipeline = pipeline.to("cuda") pipeline.precision = "int4" def run( image, num_inference_steps: int, guidance_scale: float, seed: int ) -> tuple[Image, str]: pipe_prior_output = pipe_prior_redux(image["composite"]) start_time = time.time() result_image = pipeline( num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=torch.Generator().manual_seed(seed), **pipe_prior_output, ).images[0] latency = time.time() - start_time if latency < 1: latency = latency * 1000 latency_str = f"{latency:.2f}ms" else: latency_str = f"{latency:.2f}s" torch.cuda.empty_cache() if args.count_use: if os.path.exists("use_count.txt"): with open("use_count.txt", "r") as f: count = int(f.read()) else: count = 0 count += 1 current_time = datetime.now() print(f"{current_time}: {count}") with open("use_count.txt", "w") as f: f.write(str(count)) with open("use_record.txt", "a") as f: f.write(f"{current_time}: {count}\n") return result_image, latency_str with gr.Blocks( css_paths="assets/style.css", title=f"SVDQuant Flux.1-redux-dev Demo" ) as demo: with open("assets/description.html", "r") as f: DESCRIPTION = f.read() gpus = GPUtil.getGPUs() if len(gpus) > 0: gpu = gpus[0] memory = gpu.memoryTotal / 1024 device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory." else: device_info = "Running on CPU 🥶 This demo does not work on CPU." def get_header_str(): if args.count_use: if os.path.exists("use_count.txt"): with open("use_count.txt", "r") as f: count = int(f.read()) else: count = 0 count_info = ( f"