Commit 6dc8e717 authored by April Hu

Add gradio demo for flux.1 redux

parent 420ad33d
README.md

@@ -104,6 +104,7 @@ Please refer to [comfyui/README.md](comfyui/README.md) for the usage in [ComfyUI
* Sketch-to-Image ([pix2pix-Turbo](https://github.com/GaParmar/img2img-turbo)): see [`app/flux.1/sketch`](app/flux.1/sketch).
* Depth/Canny-to-Image ([FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/)): see [`app/flux.1/depth_canny`](app/flux.1/depth_canny).
* Inpainting ([FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev)): see [`app/flux.1/fill`](app/flux.1/fill).
* Redux ([FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev)): see [`app/flux.1/redux`](app/flux.1/redux).
* SANA:
* Text-to-image: see [`app/sana/t2i`](app/sana/t2i).
app/flux.1/redux/README.md
# Nunchaku INT4 FLUX.1 Redux Demo

![demo](./assets/demo.png)

This interactive Gradio application generates variations of an input image. The base model is [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev); [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) preprocesses the input image before it is fed into FLUX.1-dev. To launch the application, run:
```shell
python run_gradio.py
```
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
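
For reference, the demo's core (minus the Gradio UI) is roughly the sketch below, mirroring `run_gradio.py` in this directory. The image path is a placeholder, and the seed/steps/guidance values are just the demo's defaults:

```python
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

# Stage 1: the Redux prior encodes a reference image into prompt embeddings.
redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Stage 2: FLUX.1-dev decodes those embeddings into an image variation.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("path/to/reference.png")  # placeholder: any RGB image
prior_output = redux(image)  # prompt_embeds + pooled_prompt_embeds
variation = pipe(
    num_inference_steps=50,
    guidance_scale=2.5,
    generator=torch.Generator().manual_seed(233),
    **prior_output,
).images[0]
variation.save("variation.png")
```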
app/flux.1/redux/assets/description.html
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-Redux-dev Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant'>
[Website]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant'>
[Blog]
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
{count_info}
</div>
</div>
app/flux.1/redux/assets/style.css
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
.gradio-container{max-width: 1200px !important}
h1{text-align:center}
.wrap.svelte-p4aq0j.svelte-p4aq0j {
display: none;
}
#column_input, #column_output {
width: 500px;
display: flex;
align-items: center;
}
#input_header, #output_header {
display: flex;
justify-content: center;
align-items: center;
width: 400px;
}
#accessibility {
text-align: center; /* Center-aligns the text */
margin: auto; /* Centers the element horizontally */
}
#random_seed {height: 71px;}
#run_button {height: 87px;}
app/flux.1/redux/run_gradio.py
# Adapted from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import time
from datetime import datetime
import GPUtil
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from PIL import Image
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED
# import gradio last to avoid conflicts with other imports
import gradio as gr
args = get_args()
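# The Redux prior encodes a reference image into the prompt embeddings that
# FLUX.1-dev consumes, so the demo needs no text prompt.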
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
if args.precision == "bf16":
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision == "int4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-dev"
)
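    # The Redux prior already supplies prompt embeddings, so the text encoders
    # are dropped; the SVDQuant INT4 transformer replaces the BF16 one.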
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
transformer=transformer,
torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
def run(
    image, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image.Image, str]:
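    # Gradio's ImageEditor returns a dict of layers; 'composite' is the
    # flattened image that the Redux prior encodes into prompt embeddings.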
pipe_prior_output = pipe_prior_redux(image["composite"])
start_time = time.time()
result_image = pipeline(
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed),
**pipe_prior_output,
).images[0]
latency = time.time() - start_time
if latency < 1:
latency = latency * 1000
latency_str = f"{latency:.2f}ms"
else:
latency_str = f"{latency:.2f}s"
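    # Free cached GPU memory between requests.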
torch.cuda.empty_cache()
if args.count_use:
if os.path.exists("use_count.txt"):
with open("use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count += 1
current_time = datetime.now()
print(f"{current_time}: {count}")
with open("use_count.txt", "w") as f:
f.write(str(count))
with open("use_record.txt", "a") as f:
f.write(f"{current_time}: {count}\n")
return result_image, latency_str
with gr.Blocks(
css_paths="assets/style.css", title=f"SVDQuant Flux.1-redux-dev Demo"
) as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
def get_header_str():
if args.count_use:
if os.path.exists("use_count.txt"):
with open("use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count_info = (
f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
)
else:
count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, count_info=count_info)
return header_str
header = gr.HTML(get_header_str())
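    # Re-render the header on each page load so the usage counter stays current.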
demo.load(fn=get_header_str, outputs=header)
with gr.Row(elem_id="main_row"):
with gr.Column(elem_id="column_input"):
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.ImageEditor(
height=640,
image_mode="RGB",
sources=["upload", "clipboard"],
type="pil",
label="Input",
show_label=False,
show_download_button=True,
interactive=True,
transforms=[],
canvas_size=(1024, 1024),
scale=1,
format="png",
layers=False,
)
with gr.Row():
run_button = gr.Button("Run", elem_id="run_button")
with gr.Row():
seed = gr.Slider(
label="Seed",
show_label=True,
minimum=0,
maximum=MAX_SEED,
value=233,
step=1,
scale=4,
)
randomize_seed = gr.Button(
"Random Seed", scale=1, min_width=50, elem_id="random_seed"
)
with gr.Accordion("Advanced options", open=False):
with gr.Group():
num_inference_steps = gr.Slider(
label="Inference Steps",
minimum=10,
maximum=50,
step=1,
value=DEFAULT_INFERENCE_STEP,
)
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=10,
step=0.5,
value=DEFAULT_GUIDANCE,
)
with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
with gr.Group():
result = gr.Image(
format="png",
height=640,
image_mode="RGB",
type="pil",
label="Result",
show_label=False,
show_download_button=True,
interactive=False,
elem_id="output_image",
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
gr.Markdown("### Instructions")
gr.Markdown("**1**. Upload or paste an image")
            gr.Markdown(
                "**2**. Adjust the guidance scale and inference steps using the sliders under Advanced options"
            )
gr.Markdown("**3**. Try different seeds to generate different results")
run_inputs = [canvas, num_inference_steps, guidance_scale, seed]
run_outputs = [result, latency_result]
gr.Examples(examples=EXAMPLES, inputs=run_inputs, outputs=run_outputs, fn=run)
randomize_seed.click(
lambda: random.randint(0, MAX_SEED),
inputs=[],
outputs=seed,
api_name=False,
queue=False,
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.on(
triggers=[run_button.click],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)
gr.Markdown(
"MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility"
)
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
app/flux.1/redux/utils.py

import argparse

def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-p",
"--precision",
type=str,
default="int4",
choices=["int4", "bf16"],
help="Which precisions to use",
)
parser.add_argument(
"--count-use", action="store_true", help="Whether to count the number of uses"
)
parser.add_argument("--gradio-root-path", type=str, default="")
args = parser.parse_args()
return args
app/flux.1/redux/vars.py

MAX_SEED = 1000000000
DEFAULT_GUIDANCE = 2.5
DEFAULT_INFERENCE_STEP = 50
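# Each example row matches run_inputs in run_gradio.py:
# [input image, inference steps, guidance scale, seed].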
EXAMPLES = [
[
"https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png",
DEFAULT_INFERENCE_STEP,
DEFAULT_GUIDANCE,
1,
]
]