Unverified Commit 07f07563 authored by Muyang Li, committed by GitHub

chore: release v0.3.1

parents 7214300d ad92b16a
......@@ -125,7 +125,7 @@ If you're using a Blackwell GPU (e.g., 50-series GPUs), install a wheel with PyT
pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
# For gradio demos
pip install peft opencv-python gradio spaces GPUtil
pip install peft opencv-python gradio spaces
```
To enable NVFP4 on Blackwell GPUs (e.g., 50-series GPUs), install a nightly build of PyTorch>=2.7 with CUDA>=12.8. The installation command can be:
......
......@@ -122,7 +122,7 @@ pip install https://huggingface.co/mit-han-lab/nunchaku/resolve/main/nunchaku-0.
pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
# For Gradio demos
pip install peft opencv-python gradio spaces GPUtil
pip install peft opencv-python gradio spaces
```
Blackwell users need to install PyTorch>=2.7 with CUDA>=12.8:
......
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-{model_name}-dev Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
<a href="https://hanlab.mit.edu/projects/svdquant">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
</h4>
<h1 style="margin-top: 0;">INT4 FLUX.1-{model_name}-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
......
......@@ -4,7 +4,6 @@ import random
import time
from datetime import datetime
import GPUtil
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
......@@ -54,12 +53,16 @@ if args.precision == "bf16":
else:
assert args.precision == "int4"
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-{model_name}")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-{model_name}/svdq-int4_r32-flux.1-{model_name}.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = pipeline_class.from_pretrained(
......@@ -125,11 +128,12 @@ def run(
with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name} Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
# Get the GPU properties
if torch.cuda.device_count() > 0:
gpu_properties = torch.cuda.get_device_properties(0)
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
gpu_name = torch.cuda.get_device_name(0)
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
......
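The demos now load quantized checkpoints by joining the Hugging Face repo id and the safetensors filename into a single path, as in the hunks above. A minimal loading sketch under this convention, mirroring the cached FLUX.1-dev example later in this release:

```python
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel

# "<repo-id>/<filename>.safetensors" in one string; nunchaku resolves the repo
# and fetches just that file.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
```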
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-fill-dev Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
<a href="https://hanlab.mit.edu/projects/svdquant">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a>
</div>
<!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-fill-dev Demo</h1>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
......
......@@ -4,7 +4,6 @@ import random
import time
from datetime import datetime
import GPUtil
import torch
from diffusers import FluxFillPipeline
from PIL import Image
......@@ -26,12 +25,16 @@ if args.precision == "bf16":
else:
assert args.precision == "int4"
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-int4_r32-flux.1-fill-dev.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxFillPipeline.from_pretrained(
......@@ -97,11 +100,12 @@ def run(
with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
# Get the GPU properties
if torch.cuda.device_count() > 0:
gpu_properties = torch.cuda.get_device_properties(0)
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
gpu_name = torch.cuda.get_device_name(0)
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
......
# Nunchaku INT4 FLUX.1 Redux Demo
![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/app/flux.1/redux/assets/demo.jpg)
![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/app/flux.1/redux/assets/demo.png)
This interactive Gradio application generates image variations. The base model is [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev). We use [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) to preprocess the input image before feeding it to FLUX.1-dev. To launch the application, run:
......
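For orientation, a minimal sketch of the Redux-then-dev chaining the app performs, using the plain diffusers API (the demo additionally swaps in the nunchaku INT4 transformer; the input path is a placeholder):

```python
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,   # Redux supplies the conditioning embeddings
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("path/to/input.png")  # placeholder path
prior_output = pipe_prior_redux(image)   # image -> conditioning embeddings
variation = pipeline(
    guidance_scale=2.5, num_inference_steps=50, **prior_output
).images[0]
```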
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-redux-dev Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant'>
[Website]
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant'>
[Blog]
<a href="https://hanlab.mit.edu/projects/svdquant">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
</h4>
<!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-redux-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
......
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
.gradio-container{max-width: 1200px !important}
h1{text-align:center}
.gradio-container {
max-width: 1200px !important;
margin: auto; /* Centers the element horizontally */
}
h1 {
text-align: center
}
.wrap.svelte-p4aq0j.svelte-p4aq0j {
display: none;
......@@ -22,8 +28,13 @@ h1{text-align:center}
#accessibility {
text-align: center; /* Center-aligns the text */
    margin: auto; /* Centers the element horizontally */
}
#random_seed {height: 71px;}
#run_button {height: 87px;}
#random_seed {
height: 71px;
}
#run_button {
height: 87px;
}
......@@ -4,10 +4,6 @@ import random
import time
from datetime import datetime
import GPUtil
# import gradio last to avoid conflicts with other imports
import gradio as gr
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from PIL import Image
......@@ -16,6 +12,9 @@ from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
# import gradio last to avoid conflicts with other imports
import gradio as gr # noqa: isort: skip
args = get_args()
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
......@@ -29,7 +28,9 @@ if args.precision == "bf16":
else:
assert args.precision == "int4"
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
......@@ -79,11 +80,12 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu
with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-redux-dev Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
# Get the GPU properties
if torch.cuda.device_count() > 0:
gpu_properties = torch.cuda.get_device_properties(0)
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
gpu_name = torch.cuda.get_device_name(0)
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
......
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-schnell Sketch-to-Image Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
<a href="https://hanlab.mit.edu/projects/svdquant">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
Image Control: <a href="https://github.com/GaParmar/img2img-turbo" target="_blank">img2img-turbo</a>
</h4>
<h1 style="margin-top: 0;">INT4 FLUX.1-schnell Sketch-to-Image Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
......
# Adapted from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import tempfile
import time
from datetime import datetime
import GPUtil
import numpy as np
import torch
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
......@@ -33,12 +31,16 @@ if args.precision == "bf16":
else:
assert args.precision == "int4"
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPix2pixTurboPipeline.from_pretrained(
......@@ -55,14 +57,6 @@ else:
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
def save_image(img):
if isinstance(img, dict):
img = img["composite"]
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(temp_file.name)
return temp_file.name
def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image, str]:
print(f"Prompt: {prompt}")
......@@ -116,11 +110,12 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
# Get the GPU properties
if torch.cuda.device_count() > 0:
gpu_properties = torch.cuda.get_device_properties(0)
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
gpu_name = torch.cuda.get_device_name(0)
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
......@@ -170,7 +165,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
with gr.Row():
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
prompt_template = gr.Textbox(
......@@ -207,7 +201,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
gr.Markdown("**2**. Start sketching")
......@@ -235,8 +228,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
api_name=False,
)
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
download_result.click(fn=save_image, inputs=result, outputs=download_result)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
......
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
FLUX.1-{model} Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
<!-- Logo Row -->
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
<a href="https://github.com/mit-han-lab/nunchaku">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
alt="nunchaku logo"
style="height: 150px; width: auto;"/>
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
<a href="https://hanlab.mit.edu/projects/svdquant">
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>
&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>
</h4>
<!-- Title -->
<h1 style="margin-top: 0;">FLUX.1-{model} Demo</h1>
<!-- Device Info -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
<!-- Notice -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
<!-- Count Info -->
{count_info}
</div>
</div>
......@@ -5,12 +5,11 @@ import random
import time
from datetime import datetime
import GPUtil
import spaces
import torch
from peft.tuners import lora
from utils import get_pipeline
from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, LORA_PATHS, MAX_SEED, PROMPT_TEMPLATES
from nunchaku.models.safety_checker import SafetyChecker
......@@ -98,7 +97,9 @@ def generate(
else:
assert precision == "int4"
if lora_name != "None":
pipeline.transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
lora_path = LORA_PATHS[lora_name]
lora_path = os.path.join(lora_path["name_or_path"], lora_path["weight_name"])
pipeline.transformer.update_lora_params(lora_path)
pipeline.transformer.set_lora_strength(lora_weight)
else:
pipeline.transformer.set_lora_strength(0)
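For reference, each `LORA_PATHS` entry is assumed to be a small dict carrying the two keys read above; the values below are placeholders, not the repo's actual table:

```python
# Hypothetical entry; keys match the lookup in generate(), values are illustrative.
LORA_PATHS = {
    "Example LoRA": {
        "name_or_path": "some-user/some-flux-lora",  # HF repo id (placeholder)
        "weight_name": "lora.safetensors",           # file inside that repo (placeholder)
    },
}
```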
......@@ -157,11 +158,13 @@ def generate(
with open("./assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
# Get the GPU properties
if torch.cuda.device_count() > 0:
gpu_properties = torch.cuda.get_device_properties(0)
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
gpu_name = torch.cuda.get_device_name(0)
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
......
......@@ -28,17 +28,21 @@ def get_pipeline(
if precision in ["int4", "fp4"]:
assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4"
"mit-han-lab/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors", precision="fp4"
)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
else:
assert precision == "bf16"
......@@ -47,7 +51,9 @@ def get_pipeline(
)
elif model_name == "dev":
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
)
if lora_name not in ["All", "None"]:
transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
transformer.set_lora_strength(lora_weight)
......@@ -55,7 +61,9 @@ def get_pipeline(
if use_qencoder:
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
......
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
offload=True,
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
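# Double First-Block Cache: when the first block's residual barely changes
# between steps, reuse the cached outputs of the remaining blocks (separate
# thresholds for the multi- and single-stream transformer blocks).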
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.12,
)
image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
__version__ = "0.3.0"
__version__ = "0.3.1"
......@@ -42,7 +42,6 @@ public:
if (net) {
pybind11::object cb = residual_callback;
net->set_residual_callback([cb](const Tensor &x) -> Tensor {
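// The callback may fire from a non-Python worker thread, so hold the GIL
// while invoking the Python callable and converting tensors.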
pybind11::gil_scoped_acquire gil;
torch::Tensor torch_x = to_torch(x);
pybind11::object result = cb(torch_x);
torch::Tensor torch_y = result.cast<torch::Tensor>();
......@@ -143,9 +142,17 @@ public:
temb = temb.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();
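// With CPU offloading enabled, this block's weights live on the host:
// load them on demand, run the block, then release them right after.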
if (net->isOffloadEnabled()) {
net->single_transformer_blocks.at(idx)->loadLazyParams();
}
Tensor result = net->single_transformer_blocks.at(idx)->forward(
from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));
if (net->isOffloadEnabled()) {
net->single_transformer_blocks.at(idx)->releaseLazyParams();
}
hidden_states = to_torch(result);
Tensor::synchronizeDevice();
......
......@@ -5,12 +5,19 @@ import torch
from safetensors.torch import save_file
from .diffusers_converter import to_diffusers
from .utils import is_nunchaku_format
from .utils import is_nunchaku_format, load_state_dict_in_safetensors
def compose_lora(
loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None
) -> dict[str, torch.Tensor]:
if len(loras) == 1:
if is_nunchaku_format(loras[0][0]) and abs(loras[0][1] - 1) < 1e-5:  # single LoRA at strength ~1: pass through unchanged
if isinstance(loras[0][0], str):
return load_state_dict_in_safetensors(loras[0][0], device="cpu")
else:
return loras[0][0]
composed = {}
for lora, strength in loras:
assert not is_nunchaku_format(lora)
......
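A hedged usage sketch for `compose_lora`; the import path and file names are assumptions for illustration:

```python
# Merge two diffusers-format LoRAs at different strengths into one state dict.
from nunchaku.lora.flux.compose import compose_lora  # import path assumed

composed = compose_lora(
    [("ghibsky-style.safetensors", 0.8), ("detail-booster.safetensors", 1.0)],
    output_path="composed-lora.safetensors",  # optional; the dict is also returned
)
```

The new fast path above covers the degenerate case: a single nunchaku-format LoRA at strength ~1 is loaded (or returned) as-is instead of being re-composed.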
import argparse
import logging
import os
import warnings
import torch
from diffusers.loaders import FluxLoraLoaderMixin
......@@ -9,6 +9,52 @@ from safetensors.torch import save_file
from .utils import load_state_dict_in_safetensors
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
# first check if the state_dict is in the kohya format
# like: https://civitai.com/models/1118358?modelVersionId=1256866
if any(not k.startswith("lora_transformer_") for k in state_dict):
return state_dict
else:
new_state_dict = {}
for k, v in state_dict.items():
new_k = k.replace("lora_transformer_", "transformer.")
new_k = new_k.replace("norm_out_", "norm_out.")
new_k = new_k.replace("time_text_embed_", "time_text_embed.")
new_k = new_k.replace("guidance_embedder_", "guidance_embedder.")
new_k = new_k.replace("text_embedder_", "text_embedder.")
new_k = new_k.replace("timestep_embedder_", "timestep_embedder.")
new_k = new_k.replace("single_transformer_blocks_", "single_transformer_blocks.")
new_k = new_k.replace("_attn_", ".attn.")
new_k = new_k.replace("_norm_linear.", ".norm.linear.")
new_k = new_k.replace("_proj_mlp.", ".proj_mlp.")
new_k = new_k.replace("_proj_out.", ".proj_out.")
new_k = new_k.replace("transformer_blocks_", "transformer_blocks.")
new_k = new_k.replace("to_out_0.", "to_out.0.")
new_k = new_k.replace("_ff_context_net_0_proj.", ".ff_context.net.0.proj.")
new_k = new_k.replace("_ff_context_net_2.", ".ff_context.net.2.")
new_k = new_k.replace("_ff_net_0_proj.", ".ff.net.0.proj.")
new_k = new_k.replace("_ff_net_2.", ".ff.net.2.")
new_k = new_k.replace("_norm1_context_linear.", ".norm1_context.linear.")
new_k = new_k.replace("_norm1_linear.", ".norm1.linear.")
new_k = new_k.replace(".lora_down.", ".lora_A.")
new_k = new_k.replace(".lora_up.", ".lora_B.")
new_state_dict[new_k] = v
return new_state_dict
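To make the remapping concrete, a kohya-style key (the module index here is hypothetical, not taken from a real checkpoint) passes through the replacement chain as:

```python
import torch

kohya_key = "lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight"
mapped = handle_kohya_lora({kohya_key: torch.zeros(1)})
assert next(iter(mapped)) == (
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight"
)
```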
def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
if isinstance(input_lora, str):
......@@ -16,6 +62,8 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
else:
tensors = {k: v for k, v in input_lora.items()}
tensors = handle_kohya_lora(tensors)
# Convert the FP8 tensors to BF16
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
......@@ -25,7 +73,14 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
if alphas is not None and len(alphas) > 0:
warnings.warn("Alpha values are not used in the conversion to diffusers format.")
for k, v in alphas.items():
key_A = k.replace(".alpha", ".lora_A.weight")
key_B = k.replace(".alpha", ".lora_B.weight")
assert key_A in new_tensors, f"Key {key_A} not found in new tensors."
assert key_B in new_tensors, f"Key {key_B} not found in new tensors."
rank = new_tensors[key_A].shape[0]
assert new_tensors[key_B].shape[1] == rank, f"Rank mismatch for {key_B}."
new_tensors[key_A] = new_tensors[key_A] * v / rank
if output_path is not None:
output_dir = os.path.dirname(os.path.abspath(output_path))
......
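The alpha handling above folds each kohya `alpha` into the matching `lora_A` weight. With rank r (the leading dimension of `lora_A`) and scale alpha, the effective LoRA update is

```latex
\Delta W = \frac{\alpha}{r}\, B A = B\left(\frac{\alpha}{r}\, A\right)
```

so multiplying `lora_A` by alpha/r, as the new code does, bakes the alpha in without changing the update, and diffusers-format consumers can then ignore alphas entirely.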