Commit e6cd772c authored by April Hu's avatar April Hu

Add flux1 demo for depth and canny

parent 6c333071
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
           alt="logo"
           style="height: 40px; width: auto; display: block; margin: auto;"/>
      INT4 FLUX.1-{model_name}-dev Demo
    </h1>
    <h2>SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models</h2>
    <h3>
      <a href='https://lmxyy.me'>Muyang Li*</a>,
      <a href='https://yujunlin.com'>Yujun Lin*</a>,
      <a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
      <a href='https://www.tianle.website/#/'>Tianle Cai</a>,
      <a href='https://xiuyuli.com'>Xiuyu Li</a>,
      <br>
      <a href='https://github.com/JerryGJX'>Junxian Guo</a>,
      <a href='https://xieenze.github.io'>Enze Xie</a>,
      <a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
      <a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
      and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
    </h3>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      <a href="https://arxiv.org/abs/2411.05007">[Paper]</a>&nbsp;
      <a href="https://github.com/mit-han-lab/nunchaku">[Code]</a>&nbsp;
      <a href="https://hanlab.mit.edu/projects/svdquant">[Website]</a>&nbsp;
      <a href="https://hanlab.mit.edu/blog/svdquant">[Blog]</a>
    </div>
    <h4>
      Quantization Library: <a href="https://github.com/mit-han-lab/deepcompressor">DeepCompressor</a>&nbsp;
      Inference Engine: <a href="https://github.com/mit-han-lab/nunchaku">Nunchaku</a>&nbsp;
      Image Control: <a href="https://github.com/GaParmar/img2img-turbo">img2img-turbo</a>
    </h4>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      {device_info}
    </div>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      {notice}
    </div>
    {count_info}
  </div>
</div>
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');

.gradio-container {
  max-width: 1200px !important;
}

h1 {
  text-align: center;
}

.wrap.svelte-p4aq0j.svelte-p4aq0j {
  display: none;
}

#column_input, #column_output {
  width: 500px;
  display: flex;
  align-items: center;
}

#input_header, #output_header {
  display: flex;
  justify-content: center;
  align-items: center;
  width: 400px;
}

#accessibility {
  text-align: center; /* Center-align the text */
  margin: auto;       /* Center the element horizontally */
}

#random_seed {
  height: 71px;
}

#run_button {
  height: 87px;
}
# Adapted from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime

import GPUtil
import numpy as np
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from PIL import Image

from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import (
    DEFAULT_GUIDANCE_CANNY,
    DEFAULT_GUIDANCE_DEPTH,
    DEFAULT_INFERENCE_STEP_CANNY,
    DEFAULT_INFERENCE_STEP_DEPTH,
    DEFAULT_STYLE_NAME,
    MAX_SEED,
    STYLE_NAMES,
    STYLES,
)

# import gradio last to avoid conflicts with other imports
import gradio as gr
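
# CLI configuration: --model selects the edge-conditioned (canny) or
# depth-conditioned (depth) FLUX.1-dev control pipeline; --precision selects
# the BF16 baseline or the SVDQuant INT4 transformer served by Nunchaku.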
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))

args = get_args()
model_name = f"{args.model}-dev"
pipeline_class = FluxControlPipeline
if args.model == "canny":
    processor = CannyDetector()
else:
    assert args.model == "depth", f"Model {args.model} not supported"
    processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
if args.precision == "bf16":
    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
else:
    assert args.precision == "int4"
    pipeline_init_kwargs = {}
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-{model_name}")
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        from nunchaku.models.text_encoder import NunchakuT5EncoderModel

        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "int4"

safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
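
# When the checker above flags a prompt, run() substitutes the default
# "A peaceful world." and tags the reported latency accordingly.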

def save_image(img):
    if isinstance(img, dict):
        # Gradio's Sketchpad returns a dict of layers; keep the flattened composite.
        img = img["composite"]
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    return temp_file.name

def run(
    image, prompt: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image.Image, str]:
    print(f"Prompt: {prompt}")
    if args.model == "canny":
        processed_img = processor(image["composite"]).convert("RGB")
    else:
        assert args.model == "depth"
        processed_img = processor(image["composite"])[0].convert("RGB")
    image_numpy = np.array(processed_img)
    # 3145628 = 1024 * 1024 * 3 - 100: treat the canvas as blank when all but
    # ~100 channel values are uniformly white or black.
    if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
        return blank_image, "Please input the prompt or draw something."
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)
    start_time = time.time()
    result_image = pipeline(
        prompt=prompt,
        control_image=processed_img,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator().manual_seed(int(seed)),
    ).images[0]
    latency = time.time() - start_time
    latency_str = f"{latency * 1000:.2f}ms" if latency < 1 else f"{latency:.2f}s"
    if is_unsafe_prompt:
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
    if args.count_use:
        if os.path.exists("use_count.txt"):
            with open("use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
        with open("use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")
    return result_image, latency_str
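
# Gradio UI: sketch canvas and prompt controls in the left column, the
# generated image and measured latency in the right column.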
with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant FLUX.1-{model_name} Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
    if len(gpus) > 0:
        gpu = gpus[0]
        memory = gpu.memoryTotal / 1024
        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

    def get_header_str():
        if args.count_use:
            if os.path.exists("use_count.txt"):
                with open("use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                "<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                "<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(
            model_name=args.model, device_info=device_info, notice=notice, count_info=count_info
        )
        return header_str

    header = gr.HTML(get_header_str())
    demo.load(fn=get_header_str, outputs=header)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.Sketchpad(
                    value=blank_image,
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Sketch",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
                    download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
                with gr.Row():
                    style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                    prompt_template = gr.Textbox(
                        label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                    )
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4
                    )
                    randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=DEFAULT_INFERENCE_STEP_CANNY if args.model == "canny" else DEFAULT_INFERENCE_STEP_DEPTH,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=DEFAULT_GUIDANCE_CANNY if args.model == "canny" else DEFAULT_GUIDANCE_DEPTH,
                    )
        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)
            download_result = gr.DownloadButton("Download Result", elem_id="download_result")
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
gr.Markdown("**2**. Start sketching")
gr.Markdown("**3**. Change the image style using a style template")
gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider (typically between 0.2 and 0.4)")
gr.Markdown("**5**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, prompt_template, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]

    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    gr.on(
        triggers=[prompt.submit, run_button.click, canvas.change],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
    download_result.click(fn=save_image, inputs=result, outputs=download_result)

    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")

if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)
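
# Example launches (a sketch; assumes this script is saved as app.py):
#   python app.py                                   # INT4 FLUX.1-canny-dev (defaults)
#   python app.py --model depth --precision bf16    # BF16 depth variant
#   python app.py --model depth --use-qencoder      # also use the 4-bit T5 text encoder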
import argparse


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
    )
    parser.add_argument(
        "-m", "--model", type=str, default="canny", choices=["canny", "depth"], help="Which FLUX.1 model to use"
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use the 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable the safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    args = parser.parse_args()
    return args
STYLES = {
    "None": "{prompt}",
    "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
    "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
    "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
    "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
    "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}

DEFAULT_STYLE_NAME = "3D Model"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000

DEFAULT_INFERENCE_STEP_CANNY = 50
DEFAULT_GUIDANCE_CANNY = 30.0
DEFAULT_INFERENCE_STEP_DEPTH = 30
DEFAULT_GUIDANCE_DEPTH = 10.0
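
# Example: applying a style template to a user prompt (mirrors the style.change
# handler in the app above):
#   STYLES["Anime"].format(prompt="a cat")
#   # -> 'anime artwork a cat. anime style, key visual, vibrant, studio anime, highly detailed'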