Unverified commit e419dc29, authored by Muyang Li and committed by GitHub
Browse files

fix: update the kontext demos (#515)

* fix the use count

* update

* change the example step to 20

* allow embedding website

* support fp4 on blackwell

* fp4 kontext runnable

* add more examples

* add kontext examples
parent 121ee754
...@@ -2,15 +2,13 @@ ...@@ -2,15 +2,13 @@
<div> <div>
<!-- Logo Row --> <!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;"> <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku"> <a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo" alt="nunchaku logo" style="height: 150px; width: auto;" />
style="height: 150px; width: auto;"/>
</a> </a>
<a href="https://hanlab.mit.edu/projects/svdquant"> <a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo" alt="svdquant logo" style="height: 40px; width: auto;" />
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<h1 style="margin-top: 0;">INT4 FLUX.1-{model_name}-dev Demo</h1> <h1 style="margin-top: 0;">INT4 FLUX.1-{model_name}-dev Demo</h1>
......
...@@ -2,24 +2,18 @@ ...@@ -2,24 +2,18 @@
<div> <div>
<!-- Logo Row --> <!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;"> <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku"> <a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo" alt="nunchaku logo" style="height: 150px; width: auto;" />
style="height: 150px; width: auto;"/>
</a> </a>
<a href="https://hanlab.mit.edu/projects/svdquant"> <a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo" alt="svdquant logo" style="height: 40px; width: auto;" />
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<!-- Title --> <!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-fill-dev Demo</h1> <h1 style="margin-top: 0;">INT4 FLUX.1-fill-dev Demo</h1>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
......
...@@ -2,23 +2,20 @@ ...@@ -2,23 +2,20 @@
<div> <div>
<!-- Logo Row --> <!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;"> <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku"> <a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo" style="height: 150px; width: auto;" /> alt="nunchaku logo" style="height: 150px; width: auto;" />
</a> </a>
<a href="https://hanlab.mit.edu/projects/svdquant"> <a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo" style="height: 40px; width: auto;" /> alt="svdquant logo" style="height: 40px; width: auto;" />
</a> </a>
</div> </div>
<h1 style="margin-top: 0;">INT4 FLUX.1-Kontext-dev Demo</h1> <h1 style="margin-top: 0;">{precision} FLUX.1-Kontext-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info} {count_info}
</div> </div>
</div> </div>
...@@ -22,10 +22,10 @@ if args.precision == "bf16": ...@@ -22,10 +22,10 @@ if args.precision == "bf16":
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "bf16" pipeline.precision = "bf16"
else: else:
assert args.precision == "int4" assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {} pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained( transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-int4_r32-flux.1-kontext-dev.safetensors" f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
) )
pipeline_init_kwargs["transformer"] = transformer pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder: if args.use_qencoder:
...@@ -40,7 +40,7 @@ else: ...@@ -40,7 +40,7 @@ else:
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
) )
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "int4" pipeline.precision = args.precision
def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]: def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
...@@ -65,17 +65,17 @@ def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, see ...@@ -65,17 +65,17 @@ def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, see
latency_str = f"{latency:.2f}s" latency_str = f"{latency:.2f}s"
torch.cuda.empty_cache() torch.cuda.empty_cache()
if args.count_use: if args.count_use:
if os.path.exists(f"{args.model}-use_count.txt"): if os.path.exists("use_count.txt"):
with open(f"{args.model}-use_count.txt", "r") as f: with open("use_count.txt", "r") as f:
count = int(f.read()) count = int(f.read())
else: else:
count = 0 count = 0
count += 1 count += 1
current_time = datetime.now() current_time = datetime.now()
print(f"{current_time}: {count}") print(f"{current_time}: {count}")
with open(f"{args.model}-use_count.txt", "w") as f: with open("use_count.txt", "w") as f:
f.write(str(count)) f.write(str(count))
with open(f"{args.model}-use_record.txt", "a") as f: with open("use_record.txt", "a") as f:
f.write(f"{current_time}: {count}\n") f.write(f"{current_time}: {count}\n")
return result_image, latency_str return result_image, latency_str
...@@ -91,7 +91,6 @@ with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo ...@@ -91,7 +91,6 @@ with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory." device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else: else:
device_info = "Running on CPU 🥶 This demo does not work on CPU." device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
def get_header_str(): def get_header_str():
...@@ -108,7 +107,9 @@ with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo ...@@ -108,7 +107,9 @@ with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo
) )
else: else:
count_info = "" count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info) header_str = DESCRIPTION.format(
precision=args.precision.upper(), device_info=device_info, count_info=count_info
)
return header_str return header_str
header = gr.HTML(get_header_str()) header = gr.HTML(get_header_str())
......
...@@ -4,7 +4,7 @@ import argparse ...@@ -4,7 +4,7 @@ import argparse
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use" "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
) )
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder") parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker") parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
......
MAX_SEED = 1000000000 MAX_SEED = 1000000000
EXAMPLES = [ EXAMPLES = [
[
"https://images.pexels.com/photos/15460314/pexels-photo-15460314.jpeg",
"Change the color of the woman's dress to red. The background is a beach. "
"The woman is holding a sign that says 'Nunchaku is awesome'",
20,
2.5,
23,
],
[
"https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/ComfyUI-nunchaku/test_data/logo.png",
"Change the logo of 'MIT HAN Lab' to 'MIT Nunchaku' in the same style.",
20,
2.5,
233,
],
[
"https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/ComfyUI-nunchaku/test_data/monalisa.jpg",
"Convert the image to ghibli style",
20,
2.5,
2333,
],
[
"https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/ComfyUI-nunchaku/test_data/mushroom_depth.webp",
"Transform the depth map into an ethereal fantasy image. A glowing mushroom forms the roof of an ancient treehouse, "
"its golden light casting warmth over moss, tiny flowers, and a stone path. Smaller glowing mushrooms create a "
"multi-level home among the tree’s twisted branches. In the background, a misty forest with waterfalls and a "
"starry night sky adds to the magical atmosphere.",
20,
2.5,
23333,
],
[ [
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png", "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png",
"Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors", "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors",
28, 20,
2.5, 2.5,
3, 233333,
], ],
] ]
...@@ -3,20 +3,18 @@ ...@@ -3,20 +3,18 @@
<!-- Logo Row --> <!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;"> <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku"> <a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo" alt="nunchaku logo" style="height: 150px; width: auto;" />
style="height: 150px; width: auto;"/>
</a> </a>
<a href="https://hanlab.mit.edu/projects/svdquant"> <a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo" alt="svdquant logo" style="height: 40px; width: auto;" />
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<!-- Title --> <!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-redux-dev Demo</h1> <h1 style="margin-top: 0;">INT4 FLUX.1-Redux-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
......
...@@ -26,10 +26,10 @@ if args.precision == "bf16": ...@@ -26,10 +26,10 @@ if args.precision == "bf16":
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "bf16" pipeline.precision = "bf16"
else: else:
assert args.precision == "int4" assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {} pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained( transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors" f"mit-han-lab/nunchaku-flux.1-dev/svdq-{args.precision}_r32-flux.1-dev.safetensors"
) )
pipeline = FluxPipeline.from_pretrained( pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-dev",
...@@ -39,7 +39,7 @@ else: ...@@ -39,7 +39,7 @@ else:
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
) )
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "int4" pipeline.precision = args.precision
def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]: def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
......
...@@ -8,7 +8,7 @@ def get_args() -> argparse.Namespace: ...@@ -8,7 +8,7 @@ def get_args() -> argparse.Namespace:
"--precision", "--precision",
type=str, type=str,
default="int4", default="int4",
choices=["int4", "bf16"], choices=["int4", "fp4", "bf16"],
help="Which precisions to use", help="Which precisions to use",
) )
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses") parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
......
...@@ -2,15 +2,13 @@ ...@@ -2,15 +2,13 @@
<div> <div>
<!-- Logo Row --> <!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;"> <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku"> <a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo" alt="nunchaku logo" style="height: 150px; width: auto;" />
style="height: 150px; width: auto;"/>
</a> </a>
<a href="https://hanlab.mit.edu/projects/svdquant"> <a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo" alt="svdquant logo" style="height: 40px; width: auto;" />
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<h1 style="margin-top: 0;">INT4 FLUX.1-schnell Sketch-to-Image Demo</h1> <h1 style="margin-top: 0;">INT4 FLUX.1-schnell Sketch-to-Image Demo</h1>
......
...@@ -2,15 +2,13 @@ ...@@ -2,15 +2,13 @@
<div> <div>
<!-- Logo Row --> <!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;"> <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku"> <a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo" alt="nunchaku logo" style="height: 150px; width: auto;" />
style="height: 150px; width: auto;"/>
</a> </a>
<a href="https://hanlab.mit.edu/projects/svdquant"> <a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo" alt="svdquant logo" style="height: 40px; width: auto;" />
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment