Unverified commit e419dc29, authored by Muyang Li and committed by GitHub
Browse files

fix: update the kontext demos (#515)

* fix the use count

* update

* change the example step to 20

* allow embedding website

* support fp4 on blackwell

* fp4 kontext runnable

* add more examples

* add kontext examples
parent 121ee754
......@@ -2,15 +2,13 @@
<div>
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
alt="nunchaku logo" style="height: 150px; width: auto;" />
</a>
<a href="https://hanlab.mit.edu/projects/svdquant">
<a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
<h1 style="margin-top: 0;">INT4 FLUX.1-{model_name}-dev Demo</h1>
......
......@@ -2,24 +2,18 @@
<div>
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
alt="nunchaku logo" style="height: 150px; width: auto;" />
</a>
<a href="https://hanlab.mit.edu/projects/svdquant">
<a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
<!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-fill-dev Demo</h1>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
......
......@@ -2,23 +2,20 @@
<div>
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo" style="height: 150px; width: auto;" />
</a>
<a href="https://hanlab.mit.edu/projects/svdquant">
<a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
<h1 style="margin-top: 0;">INT4 FLUX.1-Kontext-dev Demo</h1>
<h1 style="margin-top: 0;">{precision} FLUX.1-Kontext-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info}
</div>
</div>
......@@ -22,10 +22,10 @@ if args.precision == "bf16":
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision == "int4"
assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-int4_r32-flux.1-kontext-dev.safetensors"
f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
......@@ -40,7 +40,7 @@ else:
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
pipeline.precision = args.precision
def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
......@@ -65,17 +65,17 @@ def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, see
latency_str = f"{latency:.2f}s"
torch.cuda.empty_cache()
if args.count_use:
if os.path.exists(f"{args.model}-use_count.txt"):
with open(f"{args.model}-use_count.txt", "r") as f:
if os.path.exists("use_count.txt"):
with open("use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count += 1
current_time = datetime.now()
print(f"{current_time}: {count}")
with open(f"{args.model}-use_count.txt", "w") as f:
with open("use_count.txt", "w") as f:
f.write(str(count))
with open(f"{args.model}-use_record.txt", "a") as f:
with open("use_record.txt", "a") as f:
f.write(f"{current_time}: {count}\n")
return result_image, latency_str
......@@ -91,7 +91,6 @@ with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
def get_header_str():
......@@ -108,7 +107,9 @@ with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo
)
else:
count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
header_str = DESCRIPTION.format(
precision=args.precision.upper(), device_info=device_info, count_info=count_info
)
return header_str
header = gr.HTML(get_header_str())
......
......@@ -4,7 +4,7 @@ import argparse
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
......
MAX_SEED = 1000000000
EXAMPLES = [
[
"https://images.pexels.com/photos/15460314/pexels-photo-15460314.jpeg",
"Change the color of the woman's dress to red. The background is a beach. "
"The woman is holding a sign that says 'Nunchaku is awesome'",
20,
2.5,
23,
],
[
"https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/ComfyUI-nunchaku/test_data/logo.png",
"Change the logo of 'MIT HAN Lab' to 'MIT Nunchaku' in the same style.",
20,
2.5,
233,
],
[
"https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/ComfyUI-nunchaku/test_data/monalisa.jpg",
"Convert the image to ghibli style",
20,
2.5,
2333,
],
[
"https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/ComfyUI-nunchaku/test_data/mushroom_depth.webp",
"Transform the depth map into an ethereal fantasy image. A glowing mushroom forms the roof of an ancient treehouse, "
"its golden light casting warmth over moss, tiny flowers, and a stone path. Smaller glowing mushrooms create a "
"multi-level home among the tree’s twisted branches. In the background, a misty forest with waterfalls and a "
"starry night sky adds to the magical atmosphere.",
20,
2.5,
23333,
],
[
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png",
"Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors",
28,
20,
2.5,
3,
233333,
],
]
......@@ -3,20 +3,18 @@
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
alt="nunchaku logo" style="height: 150px; width: auto;" />
</a>
<a href="https://hanlab.mit.edu/projects/svdquant">
<a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
<!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-redux-dev Demo</h1>
<h1 style="margin-top: 0;">INT4 FLUX.1-Redux-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
......
......@@ -26,10 +26,10 @@ if args.precision == "bf16":
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision == "int4"
assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{args.precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
......@@ -39,7 +39,7 @@ else:
torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
pipeline.precision = args.precision
def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
......
......@@ -8,7 +8,7 @@ def get_args() -> argparse.Namespace:
"--precision",
type=str,
default="int4",
choices=["int4", "bf16"],
choices=["int4", "fp4", "bf16"],
help="Which precisions to use",
)
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
......
......@@ -2,15 +2,13 @@
<div>
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
alt="nunchaku logo" style="height: 150px; width: auto;" />
</a>
<a href="https://hanlab.mit.edu/projects/svdquant">
<a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
<h1 style="margin-top: 0;">INT4 FLUX.1-schnell Sketch-to-Image Demo</h1>
......
......@@ -2,15 +2,13 @@
<div>
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
alt="nunchaku logo" style="height: 150px; width: auto;" />
</a>
<a href="https://hanlab.mit.edu/projects/svdquant">
<a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment