Commit e9ad0535 authored by muyangli

[major] support SANA

parent 9eb2cee0
@@ -4,9 +4,10 @@ Nunchaku is an inference engine designed for 4-bit diffusion models, as demonstr
### [Paper](http://arxiv.org/abs/2411.05007) | [Project](https://hanlab.mit.edu/projects/svdquant) | [Blog](https://hanlab.mit.edu/blog/svdquant) | [Demo](https://svdquant.mit.edu)
- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](./examples/sana_1600m_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
- **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
- **[2024-12-08]** Added [ComfyUI](https://github.com/comfyanonymous/ComfyUI) support. Please check [comfyui/README.md](comfyui/README.md) for the usage.
- **[2024-11-07]** 🔥 Our latest **W4A4** diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is publicly released! Check [**DeepCompressor**](https://github.com/mit-han-lab/deepcompressor) for the quantization library.
![teaser](./assets/teaser.jpg)
SVDQuant is a post-training quantization technique for 4-bit weights and activations that maintains visual fidelity well. On the 12B FLUX.1-dev, it achieves a 3.6× memory reduction compared to the BF16 model. By eliminating CPU offloading, it offers an 8.7× speedup over the 16-bit model when running on a 16GB laptop 4090 GPU, 3× faster than the NF4 W4A16 baseline. On PixArt-∑, it demonstrates significantly superior visual quality over other W4A4 and even W4A8 baselines. "E2E" means the end-to-end latency, including the text encoder and VAE decoder.
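For intuition, the core decomposition behind SVDQuant can be sketched in a few lines of PyTorch. This is a toy illustration under simplifying assumptions, not the fused kernels shipped in this repo: each weight matrix is split into a 16-bit low-rank branch obtained via SVD plus a 4-bit quantized residual, so the low-rank branch absorbs the outliers that would otherwise dominate the 4-bit quantization error (the full method also migrates activation outliers into the weights beforehand).

```python
import torch

def svdquant_weight_sketch(weight: torch.Tensor, rank: int = 32, n_bits: int = 4) -> torch.Tensor:
    """Toy illustration of the SVDQuant weight decomposition: W ≈ L1 @ L2 + dequant(Q).

    The 16-bit low-rank branch (L1, L2) absorbs outliers so the residual is easy to
    quantize to 4 bits. The real pipeline also smooths activation outliers into the
    weights first and fuses all of this into custom kernels.
    """
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    L1 = U[:, :rank] * S[:rank]              # low-rank factors, kept in 16-bit
    L2 = Vh[:rank, :]
    residual = weight.float() - L1 @ L2      # what remains to be quantized
    qmax = 2 ** (n_bits - 1) - 1             # symmetric INT4 range: [-7, 7]
    scale = residual.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(residual / scale), -qmax, qmax)
    return L1 @ L2 + q * scale               # the weight as seen at inference time

# quick sanity check on a random matrix
w = torch.randn(256, 256)
error = (svdquant_weight_sketch(w) - w).norm() / w.norm()
print(f"relative reconstruction error: {error:.4f}")
```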
@@ -38,7 +39,9 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
**Note**:
* Ensure your CUDA version is **≥ 12.2 on Linux** and **≥ 12.6 on Windows**.
* For Windows users, please refer to [this issue](https://github.com/mit-han-lab/nunchaku/issues/6) for instructions, and upgrade your MSVC compiler to the latest version.
* We currently support only NVIDIA GPUs with architectures sm_86 (Ampere: RTX 3090, A6000), sm_89 (Ada: RTX 4090), and sm_80 (A100); a quick check is shown below. See [this issue](https://github.com/mit-han-lab/nunchaku/issues/1) for more details.
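To check whether your GPU falls into one of these architectures, you can query its compute capability with PyTorch. This is a quick sanity check, not part of this repository:

```python
import torch

# sm_80, sm_86, and sm_89 correspond to compute capabilities (8, 0), (8, 6), and (8, 9).
if not torch.cuda.is_available():
    raise RuntimeError("No CUDA device detected.")
major, minor = torch.cuda.get_device_capability()
print(f"Detected sm_{major}{minor}: {torch.cuda.get_device_name()}")
if (major, minor) not in {(8, 0), (8, 6), (8, 9)}:
    print("This GPU architecture is not currently supported by nunchaku.")
```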
@@ -47,7 +50,7 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
```shell
conda create -n nunchaku python=3.11
conda activate nunchaku
pip install torch torchvision torchaudio
pip install diffusers ninja wheel transformers accelerate sentencepiece protobuf
pip install huggingface_hub peft opencv-python einops gradio spaces GPUtil
```
@@ -70,7 +73,7 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
## Usage Example

In [examples](examples), we provide minimal scripts for running INT4 [FLUX.1](https://github.com/black-forest-labs/flux) and [Sana](https://github.com/NVlabs/Sana) models with Nunchaku. For example, the [script](examples/flux.1-dev.py) for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) is as follows:

```python
import torch
from diffusers import FluxPipeline

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
```
Specifically, `nunchaku` shares the same APIs as [diffusers](https://github.com/huggingface/diffusers) and can be used in a similar way.
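A 4-bit SANA model can be plugged in the same way. The snippet below is only a sketch: the Nunchaku class, module path, and repository IDs are illustrative assumptions rather than names confirmed by this commit ([examples/sana_1600m_pag.py](examples/sana_1600m_pag.py) is the authoritative script), but the pattern of swapping the quantized transformer into a diffusers pipeline is the same as above.

```python
import torch
from diffusers import SanaPipeline

# Hypothetical module/class name for the quantized SANA transformer;
# see examples/sana_1600m_pag.py for the actual import used by this repo.
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel

# Repository IDs below are illustrative placeholders, not confirmed by this commit.
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=20, guidance_scale=5.0).images[0]
image.save("sana.png")
```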
## ComfyUI
@@ -94,10 +97,12 @@ Please refer to [comfyui/README.md](comfyui/README.md) for the usage in [ComfyUI
## Gradio Demos

### FLUX.1 Models

#### Text-to-Image

```shell
cd app/flux.1/t2i
python run_gradio.py
```
@@ -106,10 +111,10 @@ python run_gradio.py
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, only the INT4 DiT is loaded. Use `-p int4 bf16` to add a BF16 DiT for side-by-side comparison, or `-p bf16` to load only the BF16 model.
#### Sketch-to-Image

```shell
cd app/flux.1/i2i
python run_gradio.py
```
@@ -117,9 +122,18 @@ python run_gradio.py
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
### Sana
#### Text-to-Image
```shell
cd app/sana/t2i
python run_gradio.py
```
## Benchmark

Please refer to [app/flux.1/t2i/README.md](app/flux.1/t2i/README.md) for instructions on reproducing our paper's quality results and benchmarking inference latency on FLUX.1 models.
## Roadmap
@@ -127,9 +141,10 @@ Please refer to [app/t2i/README.md](app/t2i/README.md) for instructions on repro
- [x] ComfyUI node
- [ ] Customized LoRA conversion instructions
- [ ] Customized model quantization instructions
- [ ] FLUX.1 tools support
- [ ] Modularization
- [ ] IP-Adapter integration
- [ ] Video model support
- [ ] Metal backend
## Citation
@@ -154,7 +169,7 @@ If you find `nunchaku` useful or relevant to your research, please cite our pape
* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978), MLSys 2024
* [DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models](https://arxiv.org/abs/2402.19481), CVPR 2024
* [QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving](https://arxiv.org/abs/2405.04532), ArXiv 2024
* [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://arxiv.org/abs/2410.10629), ICLR 2025
## Acknowledgments
@@ -48,8 +48,6 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info}
</div>
</div>
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import tempfile
import time
from datetime import datetime

import GPUtil
import numpy as np
@@ -63,10 +65,11 @@ def save_image(img):
def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image, str]:
    print(f"Prompt: {prompt}")
    image_numpy = np.array(image["composite"].convert("RGB"))
    # 3145628 = 1024 * 1024 * 3 - 100: treat an (almost) entirely white or black canvas as empty
    if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
        return blank_image, "Please input the prompt or draw something."
    is_unsafe_prompt = False
    if not safety_checker(prompt):
@@ -98,9 +101,12 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
        with open("use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")

    return result_image, latency_str
@@ -115,7 +121,27 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

    def get_header_str():
        if args.count_use:
            if os.path.exists("use_count.txt"):
                with open("use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
        return header_str

    header = gr.HTML(get_header_str())
    demo.load(fn=get_header_str, outputs=header)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
@@ -40,7 +40,7 @@ python latency.py
- For FLUX.1-schnell, the defaults are 4 steps and a guidance scale of 0.
- For FLUX.1-dev, the defaults are 50 steps and a guidance scale of 3.5.
* By default, the script measures the end-to-end latency for generating a single image. To measure the latency of a single DiT forward step instead, use the `--mode step` flag.
* Specify the number of warmup and test runs using `--warmup-times` and `--test-times`. The defaults are 2 warmup runs and 10 test runs.
## Quality Results
@@ -48,5 +48,6 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info}
</div>
</div>