Commit c17a2f6e authored by muyangli

[major] add flux.1-redux; update the flux.1-tools demos

parent de9b25d6
......@@ -4,7 +4,8 @@ Nunchaku is an inference engine designed for 4-bit diffusion models, as demonstr
### [Paper](http://arxiv.org/abs/2411.05007) | [Project](https://hanlab.mit.edu/projects/svdquant) | [Blog](https://hanlab.mit.edu/blog/svdquant) | [Demo](https://svdquant.mit.edu)
- **[2025-02-04]** **🚀 4-bit [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) is here!** Enjoy a **2-3× speedup** over the original models. Check out the [examples](./examples) for usage. **Gradio demo and ComfyUI integration are coming soon!**
- **[2025-02-11]** 🔥 **FLUX.1-tools Gradio demos are now available!** Check [here] for the usage details! Our new [depth-to-image demo](https://svdquant.mit.edu/flux.1-depth-dev/) is also online—try it out!
- **[2025-02-04]** **🚀 4-bit [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) is here!** Enjoy a **2-3× speedup** over the original models. Check out the [examples](./examples) for usage. **ComfyUI integration is coming soon!**
- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](./examples/sana_1600m_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
- **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
- **[2024-12-08]** Support [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Please check [comfyui/README.md](comfyui/README.md) for the usage.
......@@ -69,12 +70,12 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
cd nunchaku
git submodule init
git submodule update
pip install -e .
pip install -e . --no-build-isolation
```
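To quickly verify the build, you can run a minimal import check (a sketch, not part of the repo; the module path matches the demo scripts below):
```python
# Minimal post-install sanity check: the import fails if the CUDA extension did not build.
import torch
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

print("CUDA available:", torch.cuda.is_available())
print("Nunchaku import OK:", NunchakuFluxTransformer2dModel.__name__)
```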
## Usage Example
In [examples](examples), we provide minimal scripts for running INT4 [FLUX.1](https://github.com/black-forest-labs/flux) and [Sana](https://github.com/NVlabs/Sana) models with Nunchaku. For example, the [script](examples/flux.1-dev.py) for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) is as follows:
In [examples](examples), we provide minimal scripts for running INT4 [FLUX.1](https://github.com/black-forest-labs/flux) and [SANA](https://github.com/NVlabs/Sana) models with Nunchaku. For example, the [script](examples/flux.1-dev.py) for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) is as follows:
```python
import torch
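# The remainder below is a minimal sketch of the typical Nunchaku FLUX.1-dev usage;
# the exact contents of examples/flux.1-dev.py may differ, and the checkpoint id
# "mit-han-lab/svdq-int4-flux.1-dev" is assumed here.
from diffusers import FluxPipeline

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# Load the SVDQuant INT4 transformer and plug it into the standard diffusers FLUX pipeline.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

image = pipeline(
    "A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("flux.1-dev.png")
```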
......@@ -98,39 +99,13 @@ Please refer to [comfyui/README.md](comfyui/README.md) for the usage in [ComfyUI
## Gradio Demos
### FLUX.1 Models
#### Text-to-Image
```shell
cd app/flux.1/t2i
python run_gradio.py
```
* The demo defaults to the FLUX.1-schnell model. To switch to the FLUX.1-dev model, use `-m dev`.
* By default, the Gemma-2B model is loaded as a safety checker. To disable this feature and save GPU memory, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, only the INT4 DiT is loaded. Use `-p int4 bf16` to add a BF16 DiT for side-by-side comparison, or `-p bf16` to load only the BF16 model.
#### Sketch-to-Image
```shell
cd app/flux.1/i2i
python run_gradio.py
```
* Similarly, the demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
### Sana
#### Text-to-Image
```shell
cd app/sana/t2i
python run_gradio.py
```
* FLUX.1 Models:
  * Text-to-image: see [`app/flux.1/t2i`](app/flux.1/t2i).
  * Sketch-to-image ([pix2pix-Turbo](https://github.com/GaParmar/img2img-turbo)): see [`app/flux.1/sketch`](app/flux.1/sketch).
  * Depth/Canny-to-image ([FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/)): see [`app/flux.1/depth_canny`](app/flux.1/depth_canny).
  * Inpainting ([FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev)): see [`app/flux.1/fill`](app/flux.1/fill).
* SANA:
  * Text-to-image: see [`app/sana/t2i`](app/sana/t2i).
## Benchmark
......
# Nunchaku INT4 FLUX.1 Depth/Canny-to-Image Demo
![demo](./assets/demo.jpg)
This interactive Gradio application transforms your uploaded image into a different style based on a text prompt. The generated image preserves either the depth map or Canny edge of the original image, depending on the selected model.
The base models are:
* [FLUX.1-Depth-dev](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) (preserves depth map)
* [FLUX.1-Canny-dev](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) (preserves Canny edge)
To launch the application, run:
```shell
python run_gradio.py
```
* By default, the model is `FLUX.1-Depth-dev`. You can add `-m canny` to switch to `FLUX.1-Canny-dev`.
* The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file
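For reference, the core depth-to-image path of this demo can be reproduced outside Gradio. The sketch below mirrors the classes used in `run_gradio.py` (`FluxControlPipeline`, `NunchakuFluxTransformer2dModel`, `DepthPreprocessor`); the INT4 checkpoint id `mit-han-lab/svdq-int4-flux.1-depth-dev` and the input path are assumptions.
```python
# Minimal sketch (not part of the repo): INT4 Nunchaku transformer + FluxControlPipeline for depth-to-image.
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# Assumed checkpoint id; see the Nunchaku model hub for the published INT4 weights.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-depth-dev")
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Depth-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Turn the input photo into a depth map, then condition generation on it.
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(load_image("path/to/input.png"))[0].convert("RGB")

image = pipeline(
    prompt="A robot made of exotic candies and chocolates of different kinds.",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=10.0,
).images[0]
image.save("depth_result.png")
```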
......@@ -40,7 +40,6 @@
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
Image Control: <a href="https://github.com/GaParmar/img2img-turbo">img2img-turbo</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
......
# Adapted from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime
import GPUtil
import numpy as np
import torch
from PIL import Image
from image_gen_aux import DepthPreprocessor
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from PIL import Image
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_INFERENCE_STEP_CANNY, DEFAULT_GUIDANCE_CANNY, DEFAULT_INFERENCE_STEP_DEPTH, \
DEFAULT_GUIDANCE_DEPTH, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
from vars import (
DEFAULT_GUIDANCE_CANNY,
DEFAULT_GUIDANCE_DEPTH,
DEFAULT_INFERENCE_STEP_CANNY,
DEFAULT_INFERENCE_STEP_DEPTH,
DEFAULT_STYLE_NAME,
EXAMPLES,
HEIGHT,
MAX_SEED,
STYLE_NAMES,
STYLES,
WIDTH,
)
# import gradio last to avoid conflicts with other imports
import gradio as gr
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
args = get_args()
pipeline_class = None
......@@ -36,11 +42,13 @@ pipeline_class = FluxControlPipeline
if args.model == "canny":
processor = CannyDetector()
else:
assert args.model == "depth", f"Model {args.model} not suppported"
assert args.model == "depth", f"Model {args.model} not supported"
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
if args.precision == "bf16":
pipeline = pipeline_class.from_pretrained(f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16)
pipeline = pipeline_class.from_pretrained(
f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16
)
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
......@@ -63,41 +71,30 @@ else:
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
def save_image(img):
if isinstance(img, dict):
img = img["composite"]
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(temp_file.name)
return temp_file.name
def run(image, prompt: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
print(f"Prompt: {prompt}")
def run(
image, prompt: str, style: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image, str]:
if args.model == "canny":
processed_img = processor(image["composite"]).convert("RGB")
else:
assert args.model == "depth"
processed_img = processor(image["composite"])[0].convert("RGB")
image_numpy = np.array(processed_img)
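# Blank-canvas heuristic: 1024 * 1024 * 3 = 3,145,728, so the threshold below means the processed control image is (almost) entirely pure white or pure black.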
if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
return blank_image, "Please input the prompt or draw something."
is_unsafe_prompt = False
if not safety_checker(prompt):
is_unsafe_prompt = True
prompt = "A peaceful world."
prompt = prompt_template.format(prompt=prompt)
print(f"Prompt: {prompt}")
start_time = time.time()
result_image = pipeline(
prompt=prompt,
control_image=processed_img,
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
prompt=prompt,
control_image=processed_img,
height=HEIGHT,
width=WIDTH,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(int(seed)),
generator=torch.Generator().manual_seed(seed),
).images[0]
latency = time.time() - start_time
......@@ -110,17 +107,17 @@ def run(image, prompt: str, prompt_template: str, num_inference_steps: int, guid
latency_str += " (Unsafe prompt detected)"
torch.cuda.empty_cache()
if args.count_use:
if os.path.exists("use_count.txt"):
with open("use_count.txt", "r") as f:
if os.path.exists(f"{args.model}-use_count.txt"):
with open(f"{args.model}-use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count += 1
current_time = datetime.now()
print(f"{current_time}: {count}")
with open("use_count.txt", "w") as f:
with open(f"{args.model}-use_count.txt", "w") as f:
f.write(str(count))
with open("use_record.txt", "a") as f:
with open(f"{args.model}-use_record.txt", "a") as f:
f.write(f"{current_time}: {count}\n")
return result_image, latency_str
......@@ -152,7 +149,9 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
)
else:
count_info = ""
header_str = DESCRIPTION.format(model_name=args.model, device_info=device_info, notice=notice, count_info=count_info)
header_str = DESCRIPTION.format(
model_name=args.model, device_info=device_info, notice=notice, count_info=count_info
)
return header_str
header = gr.HTML(get_header_str())
......@@ -162,13 +161,12 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
with gr.Column(elem_id="column_input"):
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.Sketchpad(
value=blank_image,
canvas = gr.ImageEditor(
height=640,
image_mode="RGB",
sources=["upload", "clipboard"],
type="pil",
label="Sketch",
label="Input",
show_label=False,
show_download_button=True,
interactive=True,
......@@ -181,7 +179,6 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
with gr.Row():
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
prompt_template = gr.Textbox(
......@@ -193,10 +190,20 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
with gr.Accordion("Advanced options", open=False):
with gr.Group():
num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, \
value=DEFAULT_INFERENCE_STEP_CANNY if args.model == "canny" else DEFAULT_INFERENCE_STEP_DEPTH)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=50, step=1, \
value=DEFAULT_GUIDANCE_CANNY if args.model == "canny" else DEFAULT_GUIDANCE_DEPTH)
num_inference_steps = gr.Slider(
label="Inference Steps",
minimum=10,
maximum=50,
step=1,
value=DEFAULT_INFERENCE_STEP_CANNY if args.model == "canny" else DEFAULT_INFERENCE_STEP_DEPTH,
)
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=50,
step=1,
value=DEFAULT_GUIDANCE_CANNY if args.model == "canny" else DEFAULT_GUIDANCE_DEPTH,
)
with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
......@@ -214,17 +221,18 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
gr.Markdown("**2**. Start sketching")
gr.Markdown("**1**. Enter a text prompt (e.g., a cat)")
gr.Markdown("**2**. Upload or paste an image")
gr.Markdown("**3**. Change the image style using a style template")
gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider (typically between 0.2 and 0.4)")
gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
gr.Markdown("**5**. Try different seeds to generate different results")
run_inputs = [canvas, prompt, prompt_template, num_inference_steps, guidance_scale, seed]
run_inputs = [canvas, prompt, style, prompt_template, num_inference_steps, guidance_scale, seed]
run_outputs = [result, latency_result]
gr.Examples(examples=EXAMPLES[args.model], inputs=run_inputs, outputs=run_outputs, fn=run)
randomize_seed.click(
lambda: random.randint(0, MAX_SEED),
inputs=[],
......@@ -239,19 +247,17 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
outputs=[prompt_template],
api_name=False,
queue=False,
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
)
gr.on(
triggers=[prompt.submit, run_button.click, canvas.change],
triggers=[prompt.submit, run_button.click],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
download_result.click(fn=save_image, inputs=result, outputs=download_result)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True)
demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
......@@ -7,10 +7,11 @@ def get_args() -> argparse.Namespace:
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
)
parser.add_argument(
"-m", "--model", type=str, default="canny", choices=["canny", "depth"], help="Which FLUX.1 model to use"
"-m", "--model", type=str, default="depth", choices=["canny", "depth"], help="Which FLUX.1 model to use"
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
parser.add_argument("--gradio-root-path", type=str, default="")
args = parser.parse_args()
return args
......@@ -10,7 +10,7 @@ STYLES = {
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "3D Model"
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
......@@ -19,3 +19,58 @@ DEFAULT_GUIDANCE_CANNY = 30.0
DEFAULT_INFERENCE_STEP_DEPTH = 30
DEFAULT_GUIDANCE_DEPTH = 10.0
HEIGHT = 1024
WIDTH = 1024
EXAMPLES = {
"canny": [
[
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png",
"A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts.",
DEFAULT_STYLE_NAME,
STYLES[DEFAULT_STYLE_NAME],
50,
30,
0,
],
[
"https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png",
"A wooden basked of several individual cartons of strawberries.",
DEFAULT_STYLE_NAME,
STYLES[DEFAULT_STYLE_NAME],
50,
30,
1,
],
],
"depth": [
[
"https://huggingface.co/mit-han-lab/svdq-int4-flux.1-canny-dev/resolve/main/logo_example.png",
"A logo of 'MIT HAN Lab'.",
DEFAULT_STYLE_NAME,
STYLES[DEFAULT_STYLE_NAME],
30,
10,
2,
],
[
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png",
"A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts.",
DEFAULT_STYLE_NAME,
STYLES[DEFAULT_STYLE_NAME],
30,
10,
0,
],
[
"https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png",
"A wooden basked of several individual cartons of strawberries.",
DEFAULT_STYLE_NAME,
STYLES[DEFAULT_STYLE_NAME],
30,
10,
1,
],
],
}
# Nunchaku INT4 FLUX.1 Inpainting Demo
![demo](./assets/demo.jpg)
This interactive Gradio application inpaints regions of an uploaded image based on a text prompt and a user-drawn mask. The base model is [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev). To launch the application, run:
```shell
python run_gradio.py
```
* The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file
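For reference, the non-Gradio core of this demo looks roughly like the sketch below. The INT4 checkpoint id `mit-han-lab/svdq-int4-flux.1-fill-dev` matches the Hugging Face URLs used in the bundled examples; the mask path and the exact call arguments are illustrative.
```python
# Minimal inpainting sketch (not part of the repo): INT4 Nunchaku transformer + FluxFillPipeline.
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
pipeline = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png")
mask = load_image("path/to/mask.png")  # hypothetical mask file: white marks the region to repaint

result = pipeline(
    prompt="A wooden basket of a cat.",
    image=image,
    mask_image=mask,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30,
    max_sequence_length=512,
).images[0]
result.save("fill_result.png")
```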
......@@ -40,7 +40,6 @@
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
Image Control: <a href="https://github.com/GaParmar/img2img-turbo">img2img-turbo</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
......
# Adapted from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime
import GPUtil
import numpy as np
import torch
from diffusers import FluxFillPipeline
from PIL import Image
from diffusers import FluxFillPipeline
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, EXAMPLES, MAX_SEED, STYLE_NAMES, STYLES
# import gradio last to avoid conflicts with other imports
import gradio as gr
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
args = get_args()
if args.precision == "bf16":
......@@ -47,28 +42,19 @@ else:
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
def save_image(img):
if isinstance(img, dict):
img = img["composite"]
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(temp_file.name)
return temp_file.name
def run(
image, prompt: str, style: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image, str]:
def run(image, prompt: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
print(f"Prompt: {prompt}")
image_numpy = np.array(image["composite"].convert("RGB"))
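# With a 1024x1024 RGB canvas there are 1024 * 1024 * 3 = 3,145,728 values, so this rejects an (almost) blank input when no prompt is given.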
if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
return blank_image, "Please input the prompt or draw something."
is_unsafe_prompt = False
if not safety_checker(prompt):
is_unsafe_prompt = True
prompt = "A peaceful world."
prompt = prompt_template.format(prompt=prompt)
mask = image["layers"][0].getchannel(3) # Mask is stored in the last channel
pic = image["background"].convert("RGB") # This is the original photo
mask = image["layers"][0].getchannel(3) # Mask is stored in the last channel
pic = image["background"].convert("RGB") # This is the original photo
start_time = time.time()
result_image = pipeline(
......@@ -80,7 +66,7 @@ def run(image, prompt: str, prompt_template: str, num_inference_steps: int, guid
width=1024,
num_inference_steps=num_inference_steps,
max_sequence_length=512,
generator=torch.Generator().manual_seed(seed)
generator=torch.Generator().manual_seed(seed),
).images[0]
latency = time.time() - start_time
......@@ -146,12 +132,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.ImageMask(
value=blank_image,
height=640,
image_mode="RGBA",
sources=["upload", "clipboard"],
type="pil",
label="Sketch",
label="canvas",
show_label=False,
show_download_button=True,
interactive=True,
......@@ -160,11 +145,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
scale=1,
format="png",
layers=False,
brush=gr.Brush(default_size=30),
)
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
with gr.Row():
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
prompt_template = gr.Textbox(
......@@ -176,8 +161,12 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
with gr.Accordion("Advanced options", open=False):
with gr.Group():
num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=DEFAULT_INFERENCE_STEP)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=50, step=1, value=DEFAULT_GUIDANCE)
num_inference_steps = gr.Slider(
label="Inference Steps", minimum=10, maximum=50, step=1, value=DEFAULT_INFERENCE_STEP
)
guidance_scale = gr.Slider(
label="Guidance Scale", minimum=1, maximum=50, step=1, value=DEFAULT_GUIDANCE
)
with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
......@@ -195,16 +184,16 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
gr.Markdown("**2**. Start sketching")
gr.Markdown("**1**. Enter a text prompt (e.g., a cat)")
gr.Markdown("**2**. Upload the image and draw the inpainting mask")
gr.Markdown("**3**. Change the image style using a style template")
gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider (typically between 0.2 and 0.4)")
gr.Markdown("**4**. Adjust guidance scale using the slider")
gr.Markdown("**5**. Try different seeds to generate different results")
run_inputs = [canvas, prompt, prompt_template, num_inference_steps, guidance_scale, seed]
run_inputs = [canvas, prompt, style, prompt_template, num_inference_steps, guidance_scale, seed]
run_outputs = [result, latency_result]
gr.Examples(examples=EXAMPLES, inputs=run_inputs, outputs=run_outputs, fn=run)
randomize_seed.click(
lambda: random.randint(0, MAX_SEED),
......@@ -220,19 +209,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
outputs=[prompt_template],
api_name=False,
queue=False,
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.on(
triggers=[prompt.submit, run_button.click, canvas.change],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)
gr.on(triggers=[prompt.submit, run_button.click], fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
download_result.click(fn=save_image, inputs=result, outputs=download_result)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True)
demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
......@@ -9,5 +9,6 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
parser.add_argument("--gradio-root-path", type=str, default="")
args = parser.parse_args()
return args
......@@ -10,9 +10,24 @@ STYLES = {
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "3D Model"
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
DEFAULT_GUIDANCE = 30
DEFAULT_INFERENCE_STEP = 50
HEIGHT = 1024
WIDTH = 1024
EXAMPLES = [
[
"https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png",
"A wooden basket of a cat.",
DEFAULT_STYLE_NAME,
STYLES[DEFAULT_STYLE_NAME],
DEFAULT_INFERENCE_STEP,  # order must match run_inputs: inference steps, then guidance scale
DEFAULT_GUIDANCE,
1,
]
]
# Nunchaku INT4 FLUX.1 Sketch-to-Image Demo
![demo](./assets/demo.jpg)
This interactive Gradio application transforms your drawing scribbles into realistic images given a text prompt. The base model is [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) with the [pix2pix-Turbo](https://github.com/GaParmar/img2img-turbo) sketch LoRA.
To launch the application, simply run:
```shell
python run_gradio.py
```
* The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file