Commit 54e6d065 authored by muyangli's avatar muyangli
Browse files

[major] support NVFP4; upgrade to 0.1

parent c7f41661
......@@ -4,11 +4,12 @@ Nunchaku is an inference engine designed for 4-bit diffusion models, as demonstr
### [Paper](http://arxiv.org/abs/2411.05007) | [Project](https://hanlab.mit.edu/projects/svdquant) | [Blog](https://hanlab.mit.edu/blog/svdquant) | [Demo](https://svdquant.mit.edu)
- **[2025-02-20]** 🚀 **Support NVFP4 precision on NVIDIA RTX 5090!** NVFP4 delivers superior image quality compared to INT4, offering **~3× speedup** on the RTX 5090 over BF16. Checkout the [`examples`](./examples) for usage and try [our online demo](https://svdquant.mit.edu/flux1-schnell/)!
- **[2025-02-18]** 🔥 [**Customized LoRA conversion**](#Customized-LoRA) and [**model quantization**](#Customized-Model-Quantization) instructions are now available! **[ComfyUI](./comfyui)** workflows now support **customized LoRA**, along with **FLUX.1-Tools**!
- **[2025-02-14]** 🔥 **[LoRA conversion script](nunchaku/convert_lora.py)** is now available! [ComfyUI FLUX.1-tools workflows](./comfyui) is released!
- **[2025-02-11]** 🎉 **[SVDQuant](http://arxiv.org/abs/2411.05007) has been selected as a ICLR 2025 Spotlight! FLUX.1-tools Gradio demos are now available!** Check [here](#gradio-demos) for the usage details! Our new [depth-to-image demo](https://svdquant.mit.edu/flux1-depth-dev/) is also online—try it out!
- **[2025-02-04]** **🚀 4-bit [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) is here!** Enjoy a **2-3× speedup** over the original models. Check out the [examples](./examples) for usage. **ComfyUI integration is coming soon!**
- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](./examples/sana_1600m_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](./examples/int4-sana_1600m_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
- **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
- **[2024-12-08]** Support [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Please check [comfyui/README.md](comfyui/README.md) for the usage.
- **[2024-11-07]** 🔥 Our latest **W4A4** Diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is publicly released! Check [**DeepCompressor**](https://github.com/mit-han-lab/deepcompressor) for the quantization library.
......@@ -41,6 +42,22 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
## Installation
### Wheels (linux only)
Please make sure you are using Python 3.11.
```shell
pip install nunchaku
```
**The above wheel work for PyTorch>=2.6.** For PyTorch 2.4 and 2.5, please install the corresponding wheel in our release. For example:
```shell
pip install
```
### Build from Source
**Note**:
* Ensure your CUDA version is **≥ 12.2 on Linux** and **≥ 12.6 on Windows**.
......@@ -75,11 +92,9 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
pip install -e . --no-build-isolation
```
[Optional] You can verify your installation by running `python -m nunchaku.test`. This will execute our 4-bit FLUX.1-schnell model, which may take some time to download.
## Usage Example
In [examples](examples), we provide minimal scripts for running INT4 [FLUX.1](https://github.com/black-forest-labs/flux) and [SANA](https://github.com/NVlabs/Sana) models with Nunchaku. For example, the [script](examples/flux.1-dev.py) for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) is as follows:
In [examples](examples), we provide minimal scripts for running INT4 [FLUX.1](https://github.com/black-forest-labs/flux) and [SANA](https://github.com/NVlabs/Sana) models with Nunchaku. For example, the [script](examples/int4-flux.1-dev.py) for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) is as follows:
```python
import torch
......@@ -134,7 +149,7 @@ transformer.update_lora_params(path_to_your_converted_lora)
transformer.set_lora_strength(lora_strength)
```
`path_to_your_lora` can also be a remote HuggingFace path. In [examples/flux.1-dev-lora.py](examples/flux.1-dev-lora.py), we provide a minimal example script for running [Ghibsky](https://huggingface.co/aleksa-codes/flux-ghibsky-illustration) LoRA with SVDQuant's INT4 FLUX.1-dev:
`path_to_your_lora` can also be a remote HuggingFace path. In [examples/int4-flux.1-dev-lora.py](examples/int4-flux.1-dev-lora.py), we provide a minimal example script for running [Ghibsky](https://huggingface.co/aleksa-codes/flux-ghibsky-illustration) LoRA with SVDQuant's INT4 FLUX.1-dev:
```python
import torch
......@@ -220,7 +235,7 @@ If you find `nunchaku` useful or relevant to your research, please cite our pape
* [Q-Diffusion: Quantizing Diffusion Models](https://arxiv.org/abs/2302.04304), ICCV 2023
* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978), MLSys 2024
* [DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models](https://arxiv.org/abs/2402.19481), CVPR 2024
* [QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving](https://arxiv.org/abs/2405.04532), ArXiv 2024
* [QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving](https://arxiv.org/abs/2405.04532), MLSys 2025
* [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://arxiv.org/abs/2410.10629), ICLR 2025
## Acknowledgments
......
......@@ -14,7 +14,7 @@ def get_args():
"-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
)
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precision to use"
)
parser.add_argument(
"-d", "--datasets", type=str, nargs="*", default=["MJHQ", "DCI"], help="The benchmark datasets to evaluate on."
......
......@@ -13,7 +13,7 @@ def get_args() -> argparse.Namespace:
"-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
)
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precision to use"
)
parser.add_argument(
"--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt for the image"
......
......@@ -14,7 +14,7 @@ def get_args() -> argparse.Namespace:
"-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
)
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precision to use"
)
parser.add_argument("-t", "--num-inference-steps", type=int, default=4, help="Number of inference steps")
......@@ -72,17 +72,20 @@ def main():
pipeline.transformer.register_forward_pre_hook(get_input_hook, with_kwargs=True)
pipeline(prompt=dummy_prompt, num_inference_steps=1, guidance_scale=args.guidance_scale, output_type="latent")
with torch.no_grad():
pipeline(
prompt=dummy_prompt, num_inference_steps=1, guidance_scale=args.guidance_scale, output_type="latent"
)
for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
pipeline.transformer(*inputs["args"], **inputs["kwargs"])
torch.cuda.synchronize()
for _ in trange(args.test_times, desc="Warmup", position=0, leave=False):
start_time = time.time()
pipeline.transformer(*inputs["args"], **inputs["kwargs"])
torch.cuda.synchronize()
end_time = time.time()
latency_list.append(end_time - start_time)
for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
pipeline.transformer(*inputs["args"], **inputs["kwargs"])
torch.cuda.synchronize()
for _ in trange(args.test_times, desc="Warmup", position=0, leave=False):
start_time = time.time()
pipeline.transformer(*inputs["args"], **inputs["kwargs"])
torch.cuda.synchronize()
end_time = time.time()
latency_list.append(end_time - start_time)
latency_list = sorted(latency_list)
ignored_count = int(args.ignore_ratio * len(latency_list) / 2)
......
......@@ -29,7 +29,7 @@ def get_args() -> argparse.Namespace:
type=str,
default=["int4"],
nargs="*",
choices=["int4", "bf16"],
choices=["int4", "fp4", "bf16"],
help="Which precisions to use",
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
......
......@@ -25,9 +25,15 @@ def get_pipeline(
pipeline_init_kwargs: dict = {},
) -> FluxPipeline:
if model_name == "schnell":
if precision == "int4":
if precision in ["int4", "fp4"]:
assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"/home/muyang/nunchaku_models/flux.1-schnell-nvfp4-svdq-gptq", precision="fp4"
)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
from nunchaku.models.text_encoder import NunchakuT5EncoderModel
......
import torch
from diffusers import FluxPipeline
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-fp4-flux.1-dev", precision="fp4")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
import torch
from diffusers import FluxPipeline
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
......@@ -10,4 +10,4 @@ pipeline = FluxPipeline.from_pretrained(
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell-int4.png")
image.save("flux.1-schnell.png")
__version__ = "0.0.2beta6"
__version__ = "0.1.0"
......@@ -9,9 +9,9 @@
class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
public:
void init(bool bf16, int8_t deviceId) {
void init(bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedFluxModel");
net = std::make_unique<FluxModel>(bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<FluxModel>(use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(
......
......@@ -8,7 +8,7 @@
class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM");
size_t val = 0;
......@@ -16,7 +16,7 @@ public:
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(torch::Tensor x) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment