Unverified commit e5faa2df, authored by Muyang Li, committed by GitHub

fix: fix all the nunchaku gradio demos (#442)

* bump the version to 0.3.1

* update the t2i demos

* remove the dependency on GPUtil

* update the html

* update the html

* update the html

* update the html

* fix the demos

* make the demos runnable again
parent 45afb58b
@@ -125,7 +125,7 @@ If you're using a Blackwell GPU (e.g., 50-series GPUs), install a wheel with PyT
 pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
 # For gradio demos
-pip install peft opencv-python gradio spaces GPUtil
+pip install peft opencv-python gradio spaces
 ```
 To enable NVFP4 on Blackwell GPUs (e.g., 50-series GPUs), please install nightly PyTorch>=2.7 with CUDA>=12.8. The installation command can be:
...
@@ -122,7 +122,7 @@ pip install https://huggingface.co/mit-han-lab/nunchaku/resolve/main/nunchaku-0.
 pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
 # Dependencies for the Gradio demos
-pip install peft opencv-python gradio spaces GPUtil
+pip install peft opencv-python gradio spaces
 ```
 Blackwell users need to install PyTorch>=2.7 with CUDA>=12.8:
...
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <!-- Logo Row -->
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
alt="logo" <a href="https://github.com/mit-han-lab/nunchaku">
style="height: 40px; width: auto; display: block; margin: auto;"/> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
INT4 FLUX.1-{model_name}-dev Demo alt="nunchaku logo"
</h1> style="height: 150px; width: auto;"/>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank"> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Website] alt="svdquant logo"
</a> style="height: 40px; width: auto;"/>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
</a> </a>
</div> </div>
<h4>Quantization Library: <h1 style="margin-top: 0;">INT4 FLUX.1-{model_name}-dev Demo</h1>
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
......
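The `{model_name}` and `{device_info}` tokens in the template above are `str.format` placeholders. The app.py files later in this commit read the file into `DESCRIPTION` and fill the placeholders before rendering; a sketch of the inferred usage (not a line from this commit, and the values shown are illustrative):

```python
import gradio as gr

# Inferred usage: read the HTML template, substitute the placeholders,
# then render the result as a gr.HTML block inside the demo.
with open("assets/description.html", "r") as f:
    template = f.read()

description = template.format(
    model_name="dev",
    device_info="Running on an example GPU with 24 GiB memory.",
)
with gr.Blocks() as demo:
    gr.HTML(description)
```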
@@ -4,7 +4,6 @@ import random
 import time
 from datetime import datetime

-import GPUtil
 import torch
 from controlnet_aux import CannyDetector
 from diffusers import FluxControlPipeline
@@ -54,12 +53,16 @@ if args.precision == "bf16":
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-{model_name}")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-{model_name}/svdq-int4_r32-flux.1-{model_name}.safetensors"
+    )
     pipeline_init_kwargs["transformer"] = transformer
     if args.use_qencoder:
         from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

 pipeline = pipeline_class.from_pretrained(
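Across every demo, this commit switches model loading from repo-level Hugging Face IDs (e.g. `mit-han-lab/svdq-int4-flux.1-dev`) to explicit single-file paths of the form `<repo>/<file>.safetensors` inside the renamed `nunchaku-*` repos. A minimal sketch of the new call, lifted from the hunks above with the surrounding pipeline wiring omitted:

```python
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel

# New 0.3.1 convention: "<HF repo>/<weights file>". Nunchaku resolves the repo
# part and fetches the named .safetensors file, instead of loading a whole
# repo-level checkpoint as in 0.3.0.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
)
```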
@@ -125,11 +128,12 @@ def run(
 with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name} Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
-    gpus = GPUtil.getGPUs()
-    if len(gpus) > 0:
-        gpu = gpus[0]
-        memory = gpu.memoryTotal / 1024
-        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
     notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
...
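This GPUtil-to-torch.cuda swap is repeated verbatim in every demo below; pulled out of the diff, the new device banner is a few lines that assume only that torch is installed:

```python
import torch

# Build a one-line device banner without GPUtil: torch.cuda reports the same
# name and memory information, so the extra dependency can be dropped.
# GPUtil's memoryTotal is in MiB (hence the old /1024); total_memory is in
# bytes, so the new code divides by 1024**3 to get GiB.
if torch.cuda.device_count() > 0:
    props = torch.cuda.get_device_properties(0)
    memory_gib = props.total_memory / (1024**3)  # bytes -> GiB
    device_info = f"Running on {torch.cuda.get_device_name(0)} with {memory_gib:.0f} GiB memory."
else:
    device_info = "Running on CPU 🥶 This demo does not work on CPU."
print(device_info)
```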
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <!-- Logo Row -->
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
alt="logo" <a href="https://github.com/mit-han-lab/nunchaku">
style="height: 40px; width: auto; display: block; margin: auto;"/> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
INT4 FLUX.1-fill-dev Demo alt="nunchaku logo"
</h1> style="height: 150px; width: auto;"/>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank"> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Website] alt="svdquant logo"
</a> style="height: 40px; width: auto;"/>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
</a> </a>
</div> </div>
<!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-fill-dev Demo</h1>
<h4>Quantization Library: <h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp; <a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp; Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
......
@@ -4,7 +4,6 @@ import random
 import time
 from datetime import datetime

-import GPUtil
 import torch
 from diffusers import FluxFillPipeline
 from PIL import Image
@@ -26,12 +25,16 @@ if args.precision == "bf16":
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        "mit-han-lab/nunchaku-flux.1-fill-dev/svdq-int4_r32-flux.1-fill-dev.safetensors"
+    )
     pipeline_init_kwargs["transformer"] = transformer
     if args.use_qencoder:
         from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

 pipeline = FluxFillPipeline.from_pretrained(
@@ -97,11 +100,12 @@ def run(
 with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
-    gpus = GPUtil.getGPUs()
-    if len(gpus) > 0:
-        gpu = gpus[0]
-        memory = gpu.memoryTotal / 1024
-        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
     notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
...
 # Nunchaku INT4 FLUX.1 Redux Demo

-![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/app/flux.1/redux/assets/demo.jpg)
+![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/app/flux.1/redux/assets/demo.png)

 This interactive Gradio application generates variations of an input image. The base model is [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev). We use [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) to preprocess the image before feeding it into FLUX.1-dev. To launch the application, run:
...
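For context, the two-stage flow this README describes looks roughly like the following in plain diffusers (a sketch, not this demo's exact code; the demo additionally swaps in the Nunchaku INT4 transformer, as the app.py hunks later in this commit show):

```python
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

# Stage 1: FLUX.1-Redux-dev embeds the reference image into prompt embeddings.
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Stage 2: FLUX.1-dev generates from those embeddings; the text encoders are
# unused, so they can be dropped to save memory.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("reference.jpg")  # hypothetical input path
prior_output = pipe_prior_redux(image)
variation = pipe(guidance_scale=2.5, num_inference_steps=50, **prior_output).images[0]
```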
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <!-- Logo Row -->
alt="logo" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
style="height: 40px; width: auto; display: block; margin: auto;"/> <a href="https://github.com/mit-han-lab/nunchaku">
INT4 FLUX.1-redux-dev Demo <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
</h1> alt="nunchaku logo"
<h2> style="height: 150px; width: auto;"/>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
[Code]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/projects/svdquant'> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Website] alt="svdquant logo"
</a> style="height: 40px; width: auto;"/>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant'>
[Blog]
</a> </a>
</div> </div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp; <!-- Title -->
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp; <h1 style="margin-top: 0;">INT4 FLUX.1-redux-dev Demo</h1>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
......
 @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');

-.gradio-container{max-width: 1200px !important}
-h1{text-align:center}
+.gradio-container {
+    max-width: 1200px !important;
+    margin: auto; /* Centers the element horizontally */
+}
+
+h1 {
+    text-align: center
+}

 .wrap.svelte-p4aq0j.svelte-p4aq0j {
     display: none;
@@ -25,5 +31,10 @@ h1{text-align:center}
     margin: auto; /* Centers the element horizontally */
 }

-#random_seed {height: 71px;}
-#run_button {height: 87px;}
+#random_seed {
+    height: 71px;
+}
+
+#run_button {
+    height: 87px;
+}
@@ -4,10 +4,6 @@ import random
 import time
 from datetime import datetime

-import GPUtil
-
-# import gradio last to avoid conflicts with other imports
-import gradio as gr
 import torch
 from diffusers import FluxPipeline, FluxPriorReduxPipeline
 from PIL import Image
@@ -16,6 +12,9 @@ from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED
 from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel

+# import gradio last to avoid conflicts with other imports
+import gradio as gr  # noqa: isort: skip
+
 args = get_args()

 pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
@@ -29,7 +28,9 @@ if args.precision == "bf16":
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
+    )

 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     text_encoder=None,
@@ -79,11 +80,12 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu
 with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-redux-dev Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
-    gpus = GPUtil.getGPUs()
-    if len(gpus) > 0:
-        gpu = gpus[0]
-        memory = gpu.memoryTotal / 1024
-        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
...
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <!-- Logo Row -->
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
alt="logo" <a href="https://github.com/mit-han-lab/nunchaku">
style="height: 40px; width: auto; display: block; margin: auto;"/> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
INT4 FLUX.1-schnell Sketch-to-Image Demo alt="nunchaku logo"
</h1> style="height: 150px; width: auto;"/>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank"> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Website] alt="svdquant logo"
</a> style="height: 40px; width: auto;"/>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
</a> </a>
</div> </div>
<h4>Quantization Library: <h1 style="margin-top: 0;">INT4 FLUX.1-schnell Sketch-to-Image Demo</h1>
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
Image Control: <a href="https://github.com/GaParmar/img2img-turbo" target="_blank">img2img-turbo</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
......
 # Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
 import os
 import random
-import tempfile
 import time
 from datetime import datetime

-import GPUtil
 import numpy as np
 import torch
 from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
@@ -33,12 +31,16 @@ if args.precision == "bf16":
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        "mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
+    )
     pipeline_init_kwargs["transformer"] = transformer
     if args.use_qencoder:
         from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

 pipeline = FluxPix2pixTurboPipeline.from_pretrained(
@@ -55,14 +57,6 @@ else:
 safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)

-
-def save_image(img):
-    if isinstance(img, dict):
-        img = img["composite"]
-    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
-    img.save(temp_file.name)
-    return temp_file.name
-

 def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image, str]:
     print(f"Prompt: {prompt}")
@@ -116,11 +110,12 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
 with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
-    gpus = GPUtil.getGPUs()
-    if len(gpus) > 0:
-        gpu = gpus[0]
-        memory = gpu.memoryTotal / 1024
-        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
     notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
@@ -170,7 +165,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
         with gr.Row():
             prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
             run_button = gr.Button("Run", scale=1, elem_id="run_button")
-            download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
         with gr.Row():
             style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
             prompt_template = gr.Textbox(
@@ -207,7 +201,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
             )
             latency_result = gr.Text(label="Inference Latency", show_label=True)
-            download_result = gr.DownloadButton("Download Result", elem_id="download_result")
         gr.Markdown("### Instructions")
         gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
         gr.Markdown("**2**. Start sketching")
@@ -235,8 +228,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
         api_name=False,
     )
-    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
-    download_result.click(fn=save_image, inputs=result, outputs=download_result)
-
     gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
...
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <!-- Logo Row -->
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
alt="logo" <a href="https://github.com/mit-han-lab/nunchaku">
style="height: 40px; width: auto; display: block; margin: auto;"/> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
FLUX.1-{model} Demo alt="nunchaku logo"
</h1> style="height: 150px; width: auto;"/>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank"> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Website] alt="svdquant logo"
</a> style="height: 40px; width: auto;"/>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
</a> </a>
</div> </div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a> <!-- Title -->
&nbsp; <h1 style="margin-top: 0;">FLUX.1-{model} Demo</h1>
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>
</h4> <!-- Device Info -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
<!-- Notice -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice} {notice}
</div> </div>
<!-- Count Info -->
{count_info} {count_info}
</div> </div>
</div> </div>
@@ -5,12 +5,11 @@ import random
 import time
 from datetime import datetime

-import GPUtil
 import spaces
 import torch
 from peft.tuners import lora
 from utils import get_pipeline
-from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
+from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, LORA_PATHS, MAX_SEED, PROMPT_TEMPLATES

 from nunchaku.models.safety_checker import SafetyChecker
@@ -98,7 +97,9 @@ def generate(
     else:
         assert precision == "int4"
         if lora_name != "None":
-            pipeline.transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
+            lora_path = LORA_PATHS[lora_name]
+            lora_path = os.path.join(lora_path["name_or_path"], lora_path["weight_name"])
+            pipeline.transformer.update_lora_params(lora_path)
             pipeline.transformer.set_lora_strength(lora_weight)
         else:
             pipeline.transformer.set_lora_strength(0)
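The hunk above implies that `vars.LORA_PATHS` now maps each LoRA name to a dict with `"name_or_path"` and `"weight_name"` keys, joined into a single `<repo>/<file>` path before being handed to `update_lora_params`. A sketch of the assumed structure (the entry shown is illustrative, not the repo's actual table):

```python
import os

# Assumed shape of vars.LORA_PATHS, inferred from the diff above.
LORA_PATHS = {
    # Hypothetical entry for illustration only.
    "example-lora": {
        "name_or_path": "some-user/some-flux-lora",
        "weight_name": "lora.safetensors",
    },
}

entry = LORA_PATHS["example-lora"]
lora_path = os.path.join(entry["name_or_path"], entry["weight_name"])
# -> "some-user/some-flux-lora/lora.safetensors", the same "<repo>/<file>"
#    form used for the transformer checkpoints elsewhere in this commit.
print(lora_path)
```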
@@ -157,11 +158,13 @@ def generate(
 with open("./assets/description.html", "r") as f:
     DESCRIPTION = f.read()
-gpus = GPUtil.getGPUs()
-if len(gpus) > 0:
-    gpu = gpus[0]
-    memory = gpu.memoryTotal / 1024
-    device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+
+# Get the GPU properties
+if torch.cuda.device_count() > 0:
+    gpu_properties = torch.cuda.get_device_properties(0)
+    gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+    gpu_name = torch.cuda.get_device_name(0)
+    device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
 else:
     device_info = "Running on CPU 🥶 This demo does not work on CPU."
 notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
...
@@ -28,17 +28,21 @@ def get_pipeline(
     if precision in ["int4", "fp4"]:
         assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
         if precision == "int4":
-            transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+                "mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
+            )
         else:
             assert precision == "fp4"
             transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-                "mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4"
+                "mit-han-lab/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors", precision="fp4"
             )
         pipeline_init_kwargs["transformer"] = transformer
         if use_qencoder:
             from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-            text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+            text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+                "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+            )
             pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
     else:
         assert precision == "bf16"
@@ -47,7 +51,9 @@ def get_pipeline(
         )
     elif model_name == "dev":
         if precision == "int4":
-            transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
+            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+                "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
+            )
             if lora_name not in ["All", "None"]:
                 transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
                 transformer.set_lora_strength(lora_weight)
@@ -55,7 +61,9 @@ def get_pipeline(
             if use_qencoder:
                 from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+                    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+                )
             pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
         pipeline = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
...
__version__ = "0.3.0" __version__ = "0.3.1dev"