Unverified Commit 259394ae authored by Muyang Li, committed by GitHub

feat: update the kontext examples and models (#495)

* update kontext examples

* update tests

* add tests for kontext

* remove the warning of txt_ids and img_ids

* chore: add kontext to be synced from hf to ms

* add kontext demo

* make linter happy

* style: make linter happy

* update docs
parent 865561de
@@ -19,6 +19,7 @@ jobs:
- nunchaku-flux.1-canny-dev
- nunchaku-shuttle-jaguar
- nunchaku-sana
+- nunchaku-flux.1-kontext-dev
- svdq-fp4-flux.1-schnell
- svdq-int4-flux.1-schnell
- svdq-fp4-flux.1-dev
...
@@ -208,3 +208,5 @@ cython_debug/
*.safetensors
*.onnx
.gitattributes
+nunchaku-models/
+*.png
@@ -15,6 +15,7 @@ Join our user groups on [**Slack**](https://join.slack.com/t/nunchaku/shared_inv
## News
+- **[2025-06-29]** 🔥 Support **FLUX.1-Kontext**! Try out our [example script](./examples/flux.1-kontext-dev.py) to see it in action!
- **[2025-06-01]** 🚀 **Release v0.3.0!** This update adds support for multiple-batch inference, [**ControlNet-Union-Pro 2.0**](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0), initial integration of [**PuLID**](https://github.com/ToTheBeginning/PuLID), and introduces [**Double FB Cache**](examples/flux.1-dev-double_cache.py). You can now load Nunchaku FLUX models as a single file, and our upgraded [**4-bit T5 encoder**](https://huggingface.co/mit-han-lab/nunchaku-t5) now matches **FP8 T5** in quality!
- **[2025-04-16]** 🎥 Released tutorial videos in both [**English**](https://youtu.be/YHAVe-oM7U8?si=cM9zaby_aEHiFXk0) and [**Chinese**](https://www.bilibili.com/video/BV1BTocYjEk5/?share_source=copy_web&vd_source=8926212fef622f25cc95380515ac74ee) to assist installation and usage.
- **[2025-04-09]** 📢 Published the [April roadmap](https://github.com/mit-han-lab/nunchaku/issues/266) and an [FAQ](https://github.com/mit-han-lab/nunchaku/discussions/262) to help the community get started and stay up to date with Nunchaku’s development.
@@ -275,6 +276,12 @@ You can specify individual strengths for each LoRA in the list. For a complete e
**For ComfyUI users, you can directly use our LoRA loader. The converted LoRA is deprecated. Please refer to [mit-han-lab/ComfyUI-nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku) for more details.**
+## Kontext
+Nunchaku supports [FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev), which enables natural-language image editing. You can find the [example script](./examples/flux.1-kontext-dev.py) in our examples directory. **Note:** This feature requires `diffusers>=0.35`.
+![kontext](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/assets/kontext.png)
## ControlNets
Nunchaku supports both the [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) and the [FLUX.1-dev-ControlNet-Union-Pro](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro) models. Example scripts can be found in the [`examples`](examples) directory.
...
# Nunchaku INT4 FLUX.1 Kontext Demo
![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/assets/kontext.png)
This interactive Gradio application allows you to edit an image with natural language. Simply run:
```shell
python run_gradio.py
```
- To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
- By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
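If helpful, here are two invocation patterns that combine these flags (both flags are defined in `utils.py`; the commands assume you run them from this demo directory):

```shell
# INT4 transformer (default) plus the 4-bit T5 text encoder for extra GPU memory savings
python run_gradio.py --use-qencoder

# Original BF16 transformer instead of the INT4 one
python run_gradio.py -p bf16
```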
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo" style="height: 150px; width: auto;" />
</a>
<a href="https://hanlab.mit.edu/projects/svdquant">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
<h1 style="margin-top: 0;">INT4 FLUX.1-Kontext-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info}
</div>
</div>
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
.gradio-container {
max-width: 1200px !important;
margin: auto; /* Centers the element horizontally */
}
h1 {
text-align: center
}
.wrap.svelte-p4aq0j.svelte-p4aq0j {
display: none;
}
#column_input, #column_output {
width: 500px;
display: flex;
align-items: center;
}
#input_header, #output_header {
display: flex;
justify-content: center;
align-items: center;
width: 400px;
}
#accessibility {
text-align: center; /* Center-aligns the text */
margin: auto; /* Centers the element horizontally */
}
#random_seed {
height: 71px;
}
#run_button {
height: 87px;
}
# Adapted from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import time
from datetime import datetime
import torch
from diffusers import FluxKontextPipeline
from PIL import Image
from utils import get_args
from vars import EXAMPLES, MAX_SEED
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
# import gradio last to avoid conflicts with other imports
import gradio as gr # noqa: isort: skip
args = get_args()
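# Build the editing pipeline: either the original BF16 FLUX.1-Kontext-dev or one that swaps in Nunchaku's 4-bit transformer.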
if args.precision == "bf16":
pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision == "int4"
pipeline_init_kwargs = {}
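# Load the SVDQuant 4-bit Kontext transformer from a single safetensors file on the Hugging Face Hub.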
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-int4_r32-flux.1-kontext-dev.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
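# Edit the uploaded image with the given prompt and return the result together with the measured latency.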
def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
img = image["composite"].convert("RGB")
start_time = time.time()
result_image = pipeline(
prompt=prompt,
image=img,
height=img.height,
width=img.width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed),
).images[0]
latency = time.time() - start_time
if latency < 1:
latency = latency * 1000
latency_str = f"{latency:.2f}ms"
else:
latency_str = f"{latency:.2f}s"
torch.cuda.empty_cache()
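# Optionally persist a simple usage counter and a timestamped record (enabled with --count-use).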
if args.count_use:
if os.path.exists(f"{args.precision}-use_count.txt"):
with open(f"{args.precision}-use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count += 1
current_time = datetime.now()
print(f"{current_time}: {count}")
with open(f"{args.precision}-use_count.txt", "w") as f:
f.write(str(count))
with open(f"{args.precision}-use_record.txt", "a") as f:
f.write(f"{current_time}: {count}\n")
return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
# Get the GPU properties
if torch.cuda.device_count() > 0:
gpu_properties = torch.cuda.get_device_properties(0)
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
gpu_name = torch.cuda.get_device_name(0)
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
def get_header_str():
if args.count_use:
if os.path.exists(f"{args.precision}-use_count.txt"):
with open(f"{args.precision}-use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count_info = (
f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
)
else:
count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
return header_str
header = gr.HTML(get_header_str())
demo.load(fn=get_header_str, outputs=header)
with gr.Row(elem_id="main_row"):
with gr.Column(elem_id="column_input"):
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.ImageEditor(
height=640,
image_mode="RGB",
sources=["upload", "clipboard"],
type="pil",
label="Input",
show_label=False,
show_download_button=True,
interactive=True,
transforms=[],
canvas_size=(1024, 1024),
scale=1,
format="png",
layers=False,
)
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
with gr.Row():
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
with gr.Accordion("Advanced options", open=False):
with gr.Group():
num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
with gr.Group():
result = gr.Image(
format="png",
height=640,
image_mode="RGB",
type="pil",
label="Result",
show_label=False,
show_download_button=True,
interactive=False,
elem_id="output_image",
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt")
gr.Markdown("**2**. Upload an image")
gr.Markdown("**3**. Try different seeds to generate different results")
run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
run_outputs = [result, latency_result]
gr.Examples(examples=EXAMPLES, inputs=run_inputs, outputs=run_outputs, fn=run)
randomize_seed.click(
lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.on(
triggers=[prompt.submit, run_button.click],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
import argparse
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
parser.add_argument("--gradio-root-path", type=str, default="")
args = parser.parse_args()
return args
MAX_SEED = 1000000000
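# Prefilled demo examples, ordered to match the Gradio inputs: input image URL, prompt, inference steps, guidance scale, seed.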
EXAMPLES = [
[
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png",
"Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors",
28,
2.5,
3,
],
]
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision
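# get_precision() picks the 4-bit format supported by the local GPU ("int4" or "fp4") so the matching checkpoint is loaded.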
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
)
pipeline = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
).convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"
image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0]
image.save("flux-kontext-dev.png")
@@ -647,16 +647,8 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
-logger.warning(
-"Passing `txt_ids` 3d torch.Tensor is deprecated."
-"Please remove the batch dimension and pass it as a 2d torch Tensor"
-)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
-logger.warning(
-"Passing `img_ids` 3d torch.Tensor is deprecated."
-"Please remove the batch dimension and pass it as a 2d torch Tensor"
-)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
...
import gc
import os
from pathlib import Path
import pytest
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing
from .utils import already_generate, compute_lpips, hash_str_to_int, offload_pipeline
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18])
def test_flux_kontext(expected_lpips: float):
gc.collect()
torch.cuda.empty_cache()
precision = get_precision()
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
results_dir_16_bit = ref_root / "bf16" / "flux.1-kontext-dev" / "kontext"
results_dir_4_bit = Path("test_results") / precision / "flux.1-kontext-dev" / "kontext"
os.makedirs(results_dir_16_bit, exist_ok=True)
os.makedirs(results_dir_4_bit, exist_ok=True)
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
).convert("RGB")
prompts = [
"Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors",
"Convert the image to ghibli style",
"help me convert it to manga style",
"Convert it to a realistic photo",
]
# First, generate results with the 16-bit model
if not already_generate(results_dir_16_bit, 4):
pipeline = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
)
# Possibly offload the model to CPU when GPU memory is scarce
pipeline = offload_pipeline(pipeline)
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(seed)).images[0]
result.save(os.path.join(results_dir_16_bit, f"{seed}.png"))
# Clean up the 16-bit model
del pipeline.transformer
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.vae
del pipeline
del result
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info() # bytes
print(f"After 16-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
# Then, generate results with the 4-bit model
if not already_generate(results_dir_4_bit, 4):
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors"
)
pipeline = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(seed)).images[0]
result.save(os.path.join(results_dir_4_bit, f"{seed}.png"))
# Clean up the 4-bit model
del pipeline
del transformer
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info() # bytes
print(f"After 4-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
lpips = compute_lpips(results_dir_16_bit, results_dir_4_bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.15
import os
+from os import PathLike
+from pathlib import Path
import datasets
import torch
@@ -18,10 +20,13 @@ def hash_str_to_int(s: str) -> int:
return hash_int
-def already_generate(save_dir: str, num_images) -> bool:
-if os.path.exists(save_dir):
-images = os.listdir(save_dir)
-images = [_ for _ in images if _.endswith(".png")]
+def already_generate(save_dir: str | PathLike[str], num_images) -> bool:
+if isinstance(save_dir, str):
+save_dir = Path(save_dir)
+assert isinstance(save_dir, Path)
+if save_dir.exists():
+images = list(save_dir.iterdir())
+images = [_ for _ in images if _.name.endswith(".png")]
if len(images) == num_images:
return True
return False
...