Commit e9ad0535 authored by muyangli

[major] support SANA

parent 9eb2cee0
......@@ -4,9 +4,10 @@ Nunchaku is an inference engine designed for 4-bit diffusion models, as demonstr
### [Paper](http://arxiv.org/abs/2411.05007) | [Project](https://hanlab.mit.edu/projects/svdquant) | [Blog](https://hanlab.mit.edu/blog/svdquant) | [Demo](https://svdquant.mit.edu)
- **[Jan 22, 2024]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
- **[Dec 8, 2024]** Support [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Please check [comfyui/README.md](comfyui/README.md) for usage instructions.
- **[Nov 7, 2024]** 🔥 Our latest **W4A4** diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is publicly released! Check [**DeepCompressor**](https://github.com/mit-han-lab/deepcompressor) for the quantization library.
- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](./examples/sana_1600m_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
- **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
- **[2024-12-08]** Support [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Please check [comfyui/README.md](comfyui/README.md) for usage instructions.
- **[2024-11-07]** 🔥 Our latest **W4A4** diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is publicly released! Check [**DeepCompressor**](https://github.com/mit-han-lab/deepcompressor) for the quantization library.
![teaser](./assets/teaser.jpg)
SVDQuant is a post-training quantization technique for 4-bit weights and activations that maintains visual fidelity well. On the 12B FLUX.1-dev, it achieves a 3.6× memory reduction compared to the BF16 model. By eliminating CPU offloading, it offers an 8.7× speedup over the 16-bit model on a 16GB laptop 4090 GPU, 3× faster than the NF4 W4A16 baseline. On PixArt-Σ, it demonstrates significantly superior visual quality over other W4A4 and even W4A8 baselines. "E2E" means the end-to-end latency, including the text encoder and VAE decoder.
......@@ -38,16 +39,18 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
**Note**:
* For Windows users, please refer to [this issue](https://github.com/mit-han-lab/nunchaku/issues/6) for instructions.
* Ensure your CUDA version is **≥ 12.2 on Linux** and **≥ 12.6 on Windows**.
* We currently support only NVIDIA GPUs with architectures sm_86 (Ampere: RTX 3090, A6000), sm_89 (Ada: RTX 4090), and sm_80 (A100). See [this issue](https://github.com/mit-han-lab/nunchaku/issues/1) for more details.
* For Windows users, please refer to [this issue](https://github.com/mit-han-lab/nunchaku/issues/6) for instructions, and upgrade your MSVC compiler to the latest version.
* We currently support only NVIDIA GPUs with architectures sm_86 (Ampere: RTX 3090, A6000), sm_89 (Ada: RTX 4090), and sm_80 (A100). See [this issue](https://github.com/mit-han-lab/nunchaku/issues/1) for more details. A quick way to check your GPU's compute capability is shown after the installation commands below.
1. Install dependencies:
```shell
conda create -n nunchaku python=3.11
conda activate nunchaku
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
pip install torch torchvision torchaudio
pip install diffusers ninja wheel transformers accelerate sentencepiece protobuf
pip install huggingface_hub peft opencv-python einops gradio spaces GPUtil
```
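Since only the sm_80, sm_86, and sm_89 architectures are supported, verifying your GPU's compute capability first can save a failed build. This optional check uses only PyTorch's standard CUDA API:
```shell
# Should print (8, 0), (8, 6), or (8, 9) on a supported GPU.
python -c "import torch; print(torch.cuda.get_device_capability())"
```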
......@@ -70,7 +73,7 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
## Usage Example
In [example.py](example.py), we provide a minimal script for running the INT4 FLUX.1-schnell model with Nunchaku.
In [examples](examples), we provide minimal scripts for running INT4 [FLUX.1](https://github.com/black-forest-labs/flux) and [Sana](https://github.com/NVlabs/Sana) models with Nunchaku. For example, the [script](examples/flux.1-dev.py) for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) is as follows:
```python
import torch
......@@ -78,15 +81,15 @@ from diffusers import FluxPipeline
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
```
Specifically, `nunchaku` shares the same APIs as [diffusers](https://github.com/huggingface/diffusers) and can be used in a similar way. The FLUX.1-dev model can be loaded in the same way by replacing all `schnell` with `dev`.
Specifically, `nunchaku` shares the same APIs as [diffusers](https://github.com/huggingface/diffusers) and can be used in a similar way.
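Sana follows the same pattern. Below is a minimal sketch of what a Sana script could look like; note that the `NunchakuSanaTransformer2DModel` class, the `mit-han-lab/svdq-int4-sana-1600m` checkpoint, and the base-model repository used here are assumptions for illustration, so consult the [Sana example](examples/sana_1600m_pag.py) for the exact imports, names, and settings:

```python
import torch
from diffusers import SanaPipeline  # requires a diffusers version with Sana support

# Hypothetical module and class names; see examples/sana_1600m_pag.py for the real ones.
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel

transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",  # assumed base model
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline(
    "A cat holding a sign that says hello world",
    num_inference_steps=20,  # typical Sana settings; adjust as needed
    guidance_scale=4.5,
).images[0]
image.save("sana.png")
```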
## ComfyUI
......@@ -94,10 +97,12 @@ Please refer to [comfyui/README.md](comfyui/README.md) for the usage in [ComfyUI
## Gradio Demos
### Text-to-Image
### FLUX.1 Models
#### Text-to-Image
```shell
cd app/t2i
cd app/flux.1/t2i
python run_gradio.py
```
......@@ -106,10 +111,10 @@ python run_gradio.py
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, only the INT4 DiT is loaded. Use `-p int4 bf16` to add a BF16 DiT for side-by-side comparison, or `-p bf16` to load only the BF16 model, as in the example below.
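For instance, to compare the INT4 and BF16 DiTs side by side while keeping memory usage down, you can combine the flags documented above:
```shell
# Load both DiTs for comparison and use the W4A16 text encoder.
python run_gradio.py -p int4 bf16 --use-qencoder
```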
### Sketch-to-Image
#### Sketch-to-Image
```shell
cd app/i2i
cd app/flux.1/i2i
python run_gradio.py
```
......@@ -117,9 +122,18 @@ python run_gradio.py
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model, as in the example below.
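For instance, to run the sketch-to-image demo with the BF16 model and the quantized text encoder, using only the flags documented above:
```shell
python run_gradio.py -p bf16 --use-qencoder
```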
### Sana
#### Text-to-Image
```shell
cd app/sana/t2i
python run_gradio.py
```
## Benchmark
Please refer to [app/t2i/README.md](app/t2i/README.md) for instructions on reproducing our paper's quality results and benchmarking inference latency.
Please refer to [app/flux.1/t2i/README.md](app/flux.1/t2i/README.md) for instructions on reproducing our paper's quality results and benchmarking inference latency on FLUX.1 models.
## Roadmap
......@@ -127,9 +141,10 @@ Please refer to [app/t2i/README.md](app/t2i/README.md) for instructions on repro
- [x] ComfyUI node
- [ ] Customized LoRA conversion instructions
- [ ] Customized model quantization instructions
- [ ] FLUX.1 tools support
- [ ] Modularization
- [ ] ControlNet and IP-Adapter integration
- [ ] Mochi and CogVideoX support
- [ ] IP-Adapter integration
- [ ] Video model support
- [ ] Metal backend
## Citation
......@@ -154,7 +169,7 @@ If you find `nunchaku` useful or relevant to your research, please cite our pape
* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978), MLSys 2024
* [DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models](https://arxiv.org/abs/2402.19481), CVPR 2024
* [QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving](https://arxiv.org/abs/2405.04532), ArXiv 2024
* [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://arxiv.org/abs/2410.10629), ArXiv 2024
* [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://arxiv.org/abs/2410.10629), ICLR 2025
## Acknowledgments
......@@ -162,4 +177,4 @@ We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Prog
We use [img2img-turbo](https://github.com/GaParmar/img2img-turbo) to train the sketch-to-image LoRA. Our text-to-image and sketch-to-image UIs are built upon [playground-v2.5](https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py) and [img2img-turbo](https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py), respectively. Our safety checker is borrowed from [hart](https://github.com/mit-han-lab/hart).
Nunchaku is also inspired by many open-source libraries, including (but not limited to) [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm), [QServe](https://github.com/mit-han-lab/qserve), [AWQ](https://github.com/mit-han-lab/llm-awq), [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), and [Atom](https://github.com/efeslab/Atom).
Nunchaku is also inspired by many open-source libraries, including (but not limited to) [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm), [QServe](https://github.com/mit-han-lab/qserve), [AWQ](https://github.com/mit-han-lab/llm-awq), [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), and [Atom](https://github.com/efeslab/Atom).
\ No newline at end of file
......@@ -48,8 +48,6 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info}
</div>
</div>
<div>
<br>
</div>
\ No newline at end of file
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime
import GPUtil
import numpy as np
......@@ -63,10 +65,11 @@ def save_image(img):
def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image, str]:
    print(f"Prompt: {prompt}")
    image_numpy = np.array(image["composite"].convert("RGB"))
    if prompt.strip() == "" and np.sum(image_numpy != 255) <= 100:
        return image["composite"], "Please input the prompt or draw something."
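    # 3145628 == 1024 * 1024 * 3 - 100 (assuming a 1024x1024 RGB canvas): treat the
    # sketch as blank when all but at most 100 channel values are pure white or pure black.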
    if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
        return blank_image, "Please input the prompt or draw something."
    is_unsafe_prompt = False
    if not safety_checker(prompt):
......@@ -98,9 +101,12 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
    else:
        count = 0
    count += 1
    print(f"Use count: {count}")
    current_time = datetime.now()
    print(f"{current_time}: {count}")
with open("use_count.txt", "w") as f:
f.write(str(count))
with open("use_record.txt", "a") as f:
f.write(f"{current_time}: {count}\n")
return result_image, latency_str
......@@ -115,7 +121,27 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
    gr.HTML(DESCRIPTION.format(device_info=device_info, notice=notice))
    def get_header_str():
        if args.count_use:
            if os.path.exists("use_count.txt"):
                with open("use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
        return header_str
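
    # Build the header once at startup; demo.load below re-renders it on each
    # page load so the displayed use count stays current.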
    header = gr.HTML(get_header_str())
    demo.load(fn=get_header_str, outputs=header)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
......
......@@ -40,7 +40,7 @@ python latency.py
- For FLUX.1-schnell, the defaults are 4 steps and a guidance scale of 0.
- For FLUX.1-dev, the defaults are 50 steps and a guidance scale of 3.5.
* By default, the script measures the end-to-end latency for generating a single image. To measure the latency of a single DiT forward step instead, use the `--mode step` flag.
* Specify the number of warmup and test runs using `--warmup_times` and `--test_times`. The defaults are 2 warmup runs and 10 test runs.
* Specify the number of warmup and test runs using `--warmup-times` and `--test-times`. The defaults are 2 warmup runs and 10 test runs; see the example below.
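For example, to measure per-step DiT latency with extra runs, using only the flags documented above:
```shell
# Measure single-step latency with 5 warmup and 20 timed runs.
python latency.py --mode step --warmup-times 5 --test-times 20
```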
## Quality Results
......
......@@ -48,5 +48,6 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info}
</div>
</div>
\ No newline at end of file