Commit c17a2f6e authored by muyangli

[major] add flux.1-redux; update the flux.1-tools demos

parent de9b25d6
@@ -4,7 +4,8 @@ Nunchaku is an inference engine designed for 4-bit diffusion models, as demonstr
### [Paper](http://arxiv.org/abs/2411.05007) | [Project](https://hanlab.mit.edu/projects/svdquant) | [Blog](https://hanlab.mit.edu/blog/svdquant) | [Demo](https://svdquant.mit.edu)
- **[2025-02-11]** 🔥 **FLUX.1-tools Gradio demos are now available!** Check the [Gradio demos](#gradio-demos) section below for usage details! Our new [depth-to-image demo](https://svdquant.mit.edu/flux.1-depth-dev/) is also online—try it out!
- **[2025-02-04]** **🚀 4-bit [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) is here!** Enjoy a **2-3× speedup** over the original models. Check out the [examples](./examples) for usage. **ComfyUI integration is coming soon!**
- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](./examples/sana_1600m_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
- **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
- **[2024-12-08]** Support [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Please check [comfyui/README.md](comfyui/README.md) for the usage.
@@ -69,12 +70,12 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
cd nunchaku
git submodule init
git submodule update
pip install -e . --no-build-isolation
```
## Usage Example
In [examples](examples), we provide minimal scripts for running INT4 [FLUX.1](https://github.com/black-forest-labs/flux) and [SANA](https://github.com/NVlabs/Sana) models with Nunchaku. For example, the [script](examples/flux.1-dev.py) for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) is as follows:
```python
import torch
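# The rest of this script is collapsed in the diff view; below is a hedged sketch of a
# minimal FLUX.1-dev run (the quantized model ID and prompt are assumptions here).
# See examples/flux.1-dev.py in the repository for the authoritative version.
from diffusers import FluxPipeline

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# Load the 4-bit SVDQuant transformer and drop it into the standard diffusers pipeline.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

image = pipeline(
    "A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("flux.1-dev.png")
```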
@@ -98,39 +99,13 @@ Please refer to [comfyui/README.md](comfyui/README.md) for the usage in [ComfyUI
## Gradio Demos
* FLUX.1 Models
  * Text-to-image: see [`app/flux.1/t2i`](app/flux.1/t2i).
  * Sketch-to-Image ([pix2pix-Turbo](https://github.com/GaParmar/img2img-turbo)): see [`app/flux.1/sketch`](app/flux.1/sketch).
  * Depth/Canny-to-Image ([FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/)): see [`app/flux.1/depth_canny`](app/flux.1/depth_canny).
  * Inpainting ([FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev)): see [`app/flux.1/fill`](app/flux.1/fill).
* SANA
  * Text-to-image: see [`app/sana/t2i`](app/sana/t2i).
## Benchmark
...
# Nunchaku INT4 FLUX.1 Depth/Canny-to-Image Demo
![demo](./assets/demo.jpg)
This interactive Gradio application transforms your uploaded image into a different style based on a text prompt. The generated image preserves either the depth map or Canny edge of the original image, depending on the selected model.
The base models are:
* [FLUX.1-Depth-dev](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) (preserves depth map)
* [FLUX.1-Canny-dev](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) (preserves Canny edge)
To launch the application, run:
```shell
python run_gradio.py
```
* By default, the model is `FLUX.1-Depth-dev`. You can add `-m canny` to switch to `FLUX.1-Canny-dev`.
* The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
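To script the same pipeline outside Gradio, the sketch below shows roughly what the demo does under the hood. Treat it as a hedged example: the `mit-han-lab/svdq-int4-flux.1-depth-dev` model ID, file names, and settings are assumptions here, and `run_gradio.py` remains the authoritative setup.

```python
import torch
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from PIL import Image

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# Load the INT4 SVDQuant transformer and wire it into the FLUX.1 control pipeline.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-depth-dev")
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Depth-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Turn the input photo into a depth map, then condition generation on it.
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(Image.open("input.png"))[0].convert("RGB")

image = pipeline(
    prompt="A logo of 'MIT HAN Lab'.",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=10.0,
).images[0]
image.save("output.png")
```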
@@ -40,7 +40,6 @@
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
...
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import time
from datetime import datetime
import GPUtil
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from PIL import Image
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import (
    DEFAULT_GUIDANCE_CANNY,
    DEFAULT_GUIDANCE_DEPTH,
    DEFAULT_INFERENCE_STEP_CANNY,
    DEFAULT_INFERENCE_STEP_DEPTH,
    DEFAULT_STYLE_NAME,
    EXAMPLES,
    HEIGHT,
    MAX_SEED,
    STYLE_NAMES,
    STYLES,
    WIDTH,
)
# import gradio last to avoid conflicts with other imports
import gradio as gr
args = get_args()
pipeline_class = None
@@ -36,11 +42,13 @@ pipeline_class = FluxControlPipeline
if args.model == "canny":
    processor = CannyDetector()
else:
    assert args.model == "depth", f"Model {args.model} not supported"
    processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
if args.precision == "bf16":
    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
else:
@@ -63,41 +71,30 @@ else:
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
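# run() below: convert the uploaded image into a depth map or Canny edge map,
# safety-check and style-format the prompt, then generate with the FLUX.1 control
# pipeline and report the inference latency (plus an optional usage counter).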
def run(
    image, prompt: str, style: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image, str]:
    if args.model == "canny":
        processed_img = processor(image["composite"]).convert("RGB")
    else:
        assert args.model == "depth"
        processed_img = processor(image["composite"])[0].convert("RGB")
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)
    print(f"Prompt: {prompt}")
    start_time = time.time()
    result_image = pipeline(
        prompt=prompt,
        control_image=processed_img,
        height=HEIGHT,
        width=WIDTH,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    latency = time.time() - start_time
@@ -110,17 +107,17 @@ def run(image, prompt: str, prompt_template: str, num_inference_steps: int, guid
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
    if args.count_use:
        if os.path.exists(f"{args.model}-use_count.txt"):
            with open(f"{args.model}-use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open(f"{args.model}-use_count.txt", "w") as f:
            f.write(str(count))
        with open(f"{args.model}-use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")
    return result_image, latency_str
@@ -152,7 +149,9 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(
            model_name=args.model, device_info=device_info, notice=notice, count_info=count_info
        )
        return header_str

    header = gr.HTML(get_header_str())
@@ -162,13 +161,12 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.ImageEditor(
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Input",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
@@ -181,7 +179,6 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
            with gr.Row():
                prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                run_button = gr.Button("Run", scale=1, elem_id="run_button")
            with gr.Row():
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                prompt_template = gr.Textbox(
@@ -193,10 +190,20 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
                randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=DEFAULT_INFERENCE_STEP_CANNY if args.model == "canny" else DEFAULT_INFERENCE_STEP_DEPTH,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=DEFAULT_GUIDANCE_CANNY if args.model == "canny" else DEFAULT_GUIDANCE_DEPTH,
                    )
        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
@@ -214,17 +221,18 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
            )
            latency_result = gr.Text(label="Inference Latency", show_label=True)
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g., a cat)")
            gr.Markdown("**2**. Upload or paste an image")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Adjust the guidance scale using the slider")
            gr.Markdown("**5**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, style, prompt_template, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]
    gr.Examples(examples=EXAMPLES[args.model], inputs=run_inputs, outputs=run_outputs, fn=run)
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
@@ -239,19 +247,17 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    )
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )
    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
@@ -7,10 +7,11 @@ def get_args() -> argparse.Namespace:
        "-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
    )
    parser.add_argument(
        "-m", "--model", type=str, default="depth", choices=["canny", "depth"], help="Which FLUX.1 model to use"
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    parser.add_argument("--gradio-root-path", type=str, default="")
    args = parser.parse_args()
    return args
@@ -10,7 +10,7 @@ STYLES = {
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
@@ -19,3 +19,58 @@ DEFAULT_GUIDANCE_CANNY = 30.0
DEFAULT_INFERENCE_STEP_DEPTH = 30
DEFAULT_GUIDANCE_DEPTH = 10.0
HEIGHT = 1024
WIDTH = 1024
EXAMPLES = {
    "canny": [
        [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png",
            "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts.",
            DEFAULT_STYLE_NAME,
            STYLES[DEFAULT_STYLE_NAME],
            50,
            30,
            0,
        ],
        [
            "https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png",
            "A wooden basket of several individual cartons of strawberries.",
            DEFAULT_STYLE_NAME,
            STYLES[DEFAULT_STYLE_NAME],
            50,
            30,
            1,
        ],
    ],
    "depth": [
        [
            "https://huggingface.co/mit-han-lab/svdq-int4-flux.1-canny-dev/resolve/main/logo_example.png",
            "A logo of 'MIT HAN Lab'.",
            DEFAULT_STYLE_NAME,
            STYLES[DEFAULT_STYLE_NAME],
            30,
            10,
            2,
        ],
        [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png",
            "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts.",
            DEFAULT_STYLE_NAME,
            STYLES[DEFAULT_STYLE_NAME],
            30,
            10,
            0,
        ],
        [
            "https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png",
            "A wooden basket of several individual cartons of strawberries.",
            DEFAULT_STYLE_NAME,
            STYLES[DEFAULT_STYLE_NAME],
            30,
            10,
            1,
        ],
    ],
}
# Nunchaku INT4 FLUX.1 Inpainting Demo
![demo](./assets/demo.jpg)
This interactive Gradio application lets you inpaint an uploaded image based on a text prompt. The base model is [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev). To launch the application, run:
```shell
python run_gradio.py
```
* The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
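For reference, the core of what the demo runs reduces to the sketch below. It is a hedged example: the `mit-han-lab/svdq-int4-flux.1-fill-dev` model ID, file names, and settings are assumptions here; `run_gradio.py` is the authoritative implementation.

```python
import torch
from diffusers import FluxFillPipeline
from PIL import Image

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# Load the INT4 SVDQuant transformer and wire it into the FLUX.1 fill (inpainting) pipeline.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
pipeline = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

image = Image.open("example.png").convert("RGB")  # original photo
mask = Image.open("mask.png").convert("L")  # white pixels mark the region to repaint

result = pipeline(
    prompt="A wooden basket of a cat.",
    image=image,
    mask_image=mask,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30,
    max_sequence_length=512,
).images[0]
result.save("inpainted.png")
```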
@@ -40,7 +40,6 @@
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
...
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import time
from datetime import datetime
import GPUtil
import torch
from diffusers import FluxFillPipeline
from PIL import Image
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, EXAMPLES, MAX_SEED, STYLE_NAMES, STYLES
# import gradio last to avoid conflicts with other imports
import gradio as gr
args = get_args()
if args.precision == "bf16":
@@ -47,28 +42,19 @@ else:
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
def run(
    image, prompt: str, style: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image, str]:
    print(f"Prompt: {prompt}")
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)
    mask = image["layers"][0].getchannel(3)  # Mask is stored in the last channel
    pic = image["background"].convert("RGB")  # This is the original photo
    start_time = time.time()
    result_image = pipeline(
@@ -80,7 +66,7 @@ def run(image, prompt: str, prompt_template: str, num_inference_steps: int, guid
        width=1024,
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    latency = time.time() - start_time
@@ -146,12 +132,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.ImageMask(
                    height=640,
                    image_mode="RGBA",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="canvas",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
@@ -160,11 +145,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
                    scale=1,
                    format="png",
                    layers=False,
                    brush=gr.Brush(default_size=30),
                )
            with gr.Row():
                prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                run_button = gr.Button("Run", scale=1, elem_id="run_button")
            with gr.Row():
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                prompt_template = gr.Textbox(
@@ -176,8 +161,12 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
                randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    num_inference_steps = gr.Slider(
                        label="Inference Steps", minimum=10, maximum=50, step=1, value=DEFAULT_INFERENCE_STEP
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=1, maximum=50, step=1, value=DEFAULT_GUIDANCE
                    )
        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
@@ -195,16 +184,16 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
            )
            latency_result = gr.Text(label="Inference Latency", show_label=True)
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g., a cat)")
            gr.Markdown("**2**. Upload the image and draw the inpainting mask")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Adjust the guidance scale using the slider")
            gr.Markdown("**5**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, style, prompt_template, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]
    gr.Examples(examples=EXAMPLES, inputs=run_inputs, outputs=run_outputs, fn=run)
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
@@ -220,19 +209,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    )
    gr.on(triggers=[prompt.submit, run_button.click], fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility") gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
if __name__ == "__main__": if __name__ == "__main__":
demo.queue().launch(debug=True, share=True) demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
@@ -9,5 +9,6 @@ def get_args() -> argparse.Namespace:
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    parser.add_argument("--gradio-root-path", type=str, default="")
    args = parser.parse_args()
    return args
@@ -10,9 +10,24 @@ STYLES = {
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
DEFAULT_GUIDANCE = 30
DEFAULT_INFERENCE_STEP = 50
HEIGHT = 1024
WIDTH = 1024
EXAMPLES = [
    [
        "https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png",
        "A wooden basket of a cat.",
        DEFAULT_STYLE_NAME,
        STYLES[DEFAULT_STYLE_NAME],
        DEFAULT_INFERENCE_STEP,
        DEFAULT_GUIDANCE,
        1,
    ]
]
# Nunchaku INT4 FLUX.1 Sketch-to-Image Demo
![demo](./assets/demo.jpg)
This interactive Gradio application transforms your drawing scribbles into realistic images given a text prompt. The base model is [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) with the [pix2pix-Turbo](https://github.com/GaParmar/img2img-turbo) sketch LoRA.
To launch the application, simply run:
```shell
python run_gradio.py
```
* The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
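All of these demos gate prompts through the same safety checker mentioned above. A hedged sketch of how that check is wired up, matching the `SafetyChecker` usage in the demo scripts:

```python
from nunchaku.models.safety_checker import SafetyChecker

# Gemma-2B-based prompt checker; constructing it with disabled=True mirrors --no-safety-checker.
safety_checker = SafetyChecker("cuda", disabled=False)

prompt = "a cat"
if not safety_checker(prompt):  # returns False for prompts judged unsafe
    prompt = "A peaceful world."  # fall back to a harmless prompt, as the demos do
```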