Commit 37c494a7 authored by Zhekai Zhang

Initial release
# Prerequisites
*.d
# Compiled Object files
*.slo
*.lo
*.o
*.obj
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
*.smod
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
*.exe
*.out
*.app
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# VS Code
.vscode/
!.vscode/settings.json
.gradio/
.DS_Store
*.log
*.pt
*.nsys-rep
*.ncu-rep
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json.git
[submodule "third_party/mio"]
path = third_party/mio
url = https://github.com/vimpunk/mio.git
[submodule "third_party/spdlog"]
path = third_party/spdlog
url = https://github.com/gabime/spdlog
[submodule "third_party/Block-Sparse-Attention"]
path = third_party/Block-Sparse-Attention
url = https://github.com/sxtyzhangzk/Block-Sparse-Attention.git
branch = nunchaku
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2024] [MIT HAN Lab]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Nunchaku
### Paper | [Project](https://hanlab.mit.edu/projects/svdquant) | [Blog](https://hanlab.mit.edu/blog/svdquant) | [Demo](https://svdquant.mit.edu)
- **[Nov 7, 2024]** 🔥 Our latest **W4A4** Diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is publicly released! Check [**DeepCompressor**](https://github.com/mit-han-lab/deepcompressor) for the quantization library.
![teaser](./assets/teaser.jpg)
SVDQuant is a post-training quantization technique for 4-bit weights and activations that maintains visual fidelity well. On the 12B FLUX.1-dev, it achieves a 3.6× memory reduction compared to the BF16 model. By eliminating CPU offloading, it offers an 8.7× speedup over the 16-bit model on a 16GB laptop 4090 GPU, running 3× faster than the NF4 W4A16 baseline. On PixArt-Σ, it demonstrates significantly superior visual quality over other W4A4 and even W4A8 baselines. "E2E" means the end-to-end latency, including the text encoder and VAE decoder.
**SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models**
Muyang Li, Yujun Lin, Zhekai Zhang, Tianle Cai, Xiuyu Li, Junxian Guo, Enze Xie, Chenlin Meng, Jun-Yan Zhu, and Song Han <br>
*MIT, NVIDIA, CMU, Princeton, UC Berkeley, SJTU, and Pika Labs* <br>
## Overview
#### Quantization Method -- SVDQuant
![intuition](./assets/intuition.gif) Overview of SVDQuant. Stage 1: Originally, both the activation $\boldsymbol{X}$ and weights $\boldsymbol{W}$ contain outliers, making 4-bit quantization challenging. Stage 2: We migrate the outliers from activations to weights, resulting in the updated activation $\hat{\boldsymbol{X}}$ and weights $\hat{\boldsymbol{W}}$. While $\hat{\boldsymbol{X}}$ becomes easier to quantize, $\hat{\boldsymbol{W}}$ now becomes more difficult. Stage 3: SVDQuant further decomposes $\hat{\boldsymbol{W}}$ into a low-rank component $\boldsymbol{L}_1\boldsymbol{L}_2$ and a residual $\hat{\boldsymbol{W}}-\boldsymbol{L}_1\boldsymbol{L}_2$ with SVD. Thus, the quantization difficulty is alleviated by the low-rank branch, which runs at 16-bit precision.
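The Stage 3 split can be sketched in a few lines of PyTorch (an illustrative sketch only, not the DeepCompressor implementation; the function name and the rank-32 default are assumptions):

```python
import torch

def svd_low_rank_split(w_hat: torch.Tensor, rank: int = 32):
    """Split a smoothed weight of shape (in_features, out_features) into a 16-bit
    low-rank branch L1 @ L2 and a residual that would go through 4-bit quantization."""
    u, s, vh = torch.linalg.svd(w_hat.float(), full_matrices=False)
    l1 = u[:, :rank] * s[:rank]          # (in_features, rank)
    l2 = vh[:rank, :]                    # (rank, out_features)
    residual = w_hat.float() - l1 @ l2   # this part is what gets quantized to 4 bits
    return l1, l2, residual
```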
#### Nunchaku Engine Design
![engine](./assets/engine.jpg) (a) Naïvely running the low-rank branch with rank 32 introduces a 57% latency overhead due to the extra read of 16-bit inputs in *Down Projection* and the extra write of 16-bit outputs in *Up Projection*. Nunchaku optimizes this overhead with kernel fusion. (b) The *Down Projection* and *Quantize* kernels use the same input, while the *Up Projection* and *4-Bit Compute* kernels share the same output. To reduce data movement overhead, we fuse the first two and the latter two kernels together.
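As a rough illustration of the data flow being fused (a PyTorch sketch for exposition only; `quantize_fn` and `w4a4_gemm_fn` are placeholders, not Nunchaku APIs):

```python
import torch

def low_rank_branch_unfused(x, l1, l2, quantize_fn, w4a4_gemm_fn):
    # Unfused reference: each step reads or writes full 16-bit tensors in global memory.
    x_q = quantize_fn(x)         # Quantize: reads the 16-bit input x
    down = x @ l1                # Down Projection: reads the same 16-bit x again
    main = w4a4_gemm_fn(x_q)     # 4-Bit Compute: writes a 16-bit output
    up = down @ l2               # Up Projection: writes another 16-bit output
    # Nunchaku fuses Quantize with Down Projection and 4-Bit Compute with Up Projection,
    # removing the redundant 16-bit reads and writes above.
    return main + up
```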
## Performance
![efficiency](./assets/efficiency.jpg) SVDQuant reduces the model size of the 12B FLUX.1 by 3.6×. Additionally, Nunchaku further cuts the memory usage of the 16-bit model by 3.5× and delivers a 3.0× speedup over the NF4 W4A16 baseline on both desktop and laptop NVIDIA RTX 4090 GPUs. Remarkably, on the laptop 4090, it achieves a total 10.1× speedup by eliminating CPU offloading.
<p align="center">
<img src="./assets/speed_demo.gif" width="80%"/>
</p>
## Installation
1. Install dependencies:
```shell
conda create -n nunchaku python=3.11
conda activate nunchaku
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
pip install diffusers ninja wheel transformers accelerate sentencepiece protobuf
pip install huggingface_hub peft opencv-python einops gradio spaces GPUtil
```
2. Install `nunchaku` package:
Make sure you have `gcc/g++>=11`. If you don't, you can install it via Conda:
```shell
conda install -c conda-forge gxx=11 gcc=11
```
Then build the package from source (a quick check to verify the build follows these commands):
```shell
git clone https://github.com/mit-han-lab/nunchaku.git
cd nunchaku
git submodule init
git submodule update
pip install -e .
```
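After the editable install completes, the following minimal check (an illustrative sketch; it only verifies that the `nunchaku` extension imports and that CUDA is visible) can catch build or environment issues early:

```python
# Minimal post-install check: confirms the compiled nunchaku extension imports
# and that a CUDA device is visible to PyTorch.
import torch
from nunchaku.pipelines import flux as nunchaku_flux  # same import used in the usage example

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```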
## Usage Example
In [example.py](example.py), we provide a minimal script for running the INT4 FLUX.1-schnell model with Nunchaku.
```python
import torch
from nunchaku.pipelines import flux as nunchaku_flux
pipeline = nunchaku_flux.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors", # download from Huggingface
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
```
Specifically, `nunchaku` shares the same APIs as [diffusers](https://github.com/huggingface/diffusers) and can be used in a similar way. The FLUX.1-dev model can be loaded in the same way by replacing all `schnell` with `dev`.
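For instance, loading the FLUX.1-dev variant could look like the sketch below; the `svdq-int4-flux.1-dev.safetensors` filename is assumed to follow the same naming pattern as the schnell checkpoint above, and the 50 steps / 3.5 guidance scale match the FLUX.1-dev defaults mentioned in the benchmark instructions:

```python
import torch
from nunchaku.pipelines import flux as nunchaku_flux

# FLUX.1-dev variant (quantized filename assumed to mirror the schnell checkpoint naming).
pipeline = nunchaku_flux.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-dev.safetensors",
).to("cuda")
image = pipeline(
    "A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("example_dev.png")
```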
## Gradio Demos
### Text-to-Image
```shell
cd app/t2i
python run_gradio.py
```
* The demo also defaults to the FLUX.1-schnell model. To switch to the FLUX.1-dev model, use `-m dev`.
* By default, the Gemma-2B model is loaded as a safety checker. To disable this feature and save GPU memory, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, only the INT4 DiT is loaded. Use `-p int4 bf16` to add a BF16 DiT for side-by-side comparison, or `-p bf16` to load only the BF16 model.
### Sketch-to-Image
```shell
cd app/t2i
python run_gradio.py
```
* Similarly, the demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
## Benchmark
Please refer to [app/t2i/README.md](app/t2i/README.md) for instructions on reproducing our paper's quality results and benchmarking inference latency.
## Citation
If you find `nunchaku` useful or relevant to your research, please cite our paper:
```bibtex
@article{
li2024svdquant,
title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
journal={arXiv preprint arXiv:TBD},
year={2024}
}
```
## Related Projects
* [Efficient Spatially Sparse Inference for Conditional GANs and Diffusion Models](https://arxiv.org/abs/2211.02048), NeurIPS 2022 & T-PAMI 2023
* [SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://arxiv.org/abs/2211.10438), ICML 2023
* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978), MLSys 2024
* [DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models](https://arxiv.org/abs/2402.19481), CVPR 2024
* [QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving](https://arxiv.org/abs/2405.04532), ArXiv 2024
* [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://arxiv.org/abs/2410.10629), ArXiv 2024
## Acknowledgments
We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Program, National Science Foundation, Packard Foundation, Dell, LG, Hyundai, and Samsung for supporting this research. We thank NVIDIA for donating the DGX server.
We use [img2img-turbo](https://github.com/GaParmar/img2img-turbo) to train the sketch-to-image LoRA. Our text-to-image and sketch-to-image UIs are built upon [playground-v2.5](https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py) and [img2img-turbo](https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py), respectively. Our safety checker is borrowed from [hart](https://github.com/mit-han-lab/hart).
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://i.ibb.co/9gh96K6/logo.png"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-schnell Sketch-to-Image Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://www.cs.cmu.edu/~srinivas/'>Chenlin Meng</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
[Paper]
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant'>
[Website]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant'>
[Blog]
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
Image Control: <a href="https://github.com/GaParmar/img2img-turbo">img2img-turbo</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
</div>
</div>
<div>
<br>
</div>
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
.gradio-container{max-width: 1200px !important}
h1{text-align:center}
.wrap.svelte-p4aq0j.svelte-p4aq0j {
display: none;
}
#column_input, #column_output {
width: 500px;
display: flex;
align-items: center;
}
#input_header, #output_header {
display: flex;
justify-content: center;
align-items: center;
width: 400px;
}
#random_seed {height: 71px;}
#run_button {height: 87px;}
import argparse
import os
import torch
from safetensors.torch import save_file
from tqdm import tqdm
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="Path to checkpoint file")
parser.add_argument(
"-o", "--output-path", type=str, help="Path to save the output checkpoint file", default="output.safetensors"
)
args = parser.parse_args()
return args
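# The checkpoint stores the final-layer adaLN parameters as [shift, scale]; the diffusers-style
# keys built below expect [scale, shift], so swap_scale_shift reorders the two halves.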
def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor:
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def main():
args = get_args()
original_state_dict = torch.load(args.input_path, map_location="cpu")
new_state_dict = {
"transformer.x_embedder.weight": original_state_dict["img_in"]["weight"],
"transformer.x_embedder.bias": original_state_dict["img_in"]["bias"],
"transformer.norm_out.linear.weight": swap_scale_shift(
original_state_dict["final_layer"]["adaLN_modulation.1.weight"]
),
"transformer.norm_out.linear.bias": swap_scale_shift(
original_state_dict["final_layer"]["adaLN_modulation.1.bias"]
),
"transformer.proj_out.weight": original_state_dict["final_layer"]["linear.weight"],
"transformer.proj_out.bias": original_state_dict["final_layer"]["linear.bias"],
}
original_state_dict.pop("img_in")
original_state_dict.pop("final_layer")
original_lora_state_dict = original_state_dict["lora"]
for k, v in tqdm(original_lora_state_dict.items()):
if "double_blocks" in k:
new_k = k.replace("double_blocks", "transformer.transformer_blocks").replace(".default", "")
if "qkv" in new_k:
for i, p in enumerate(["q", "k", "v"]):
if "lora_A" in new_k:
# Copy the tensor
new_k2 = new_k.replace("img_attn.qkv", f"attn.to_{p}")
new_k2 = new_k2.replace("txt_attn.qkv", f"attn.add_{p}_proj")
new_state_dict[new_k2] = v.clone()
else:
assert "lora_B" in new_k
assert v.shape[0] % 3 == 0
chunk_size = v.shape[0] // 3
new_k2 = new_k.replace("img_attn.qkv", f"attn.to_{p}")
new_k2 = new_k2.replace("txt_attn.qkv", f"attn.add_{p}_proj")
new_state_dict[new_k2] = v[i * chunk_size : (i + 1) * chunk_size]
else:
new_k = new_k.replace("img_mod.lin", "norm1.linear")
new_k = new_k.replace("txt_mod.lin", "norm1_context.linear")
new_k = new_k.replace("img_mlp.0", "ff.net.0.proj")
new_k = new_k.replace("img_mlp.2", "ff.net.2")
new_state_dict[new_k] = v
else:
assert "single_blocks" in k
new_k = k.replace("single_blocks", "transformer.single_transformer_blocks").replace(".default", "")
if "linear1" in k:
start = 0
for i, p in enumerate(["q", "k", "v", "i"]):
if "lora_A" in new_k:
if p == "i":
new_k2 = new_k.replace("linear1", "proj_mlp")
else:
new_k2 = new_k.replace("linear1", f"attn.to_{p}")
new_state_dict[new_k2] = v.clone()
else:
if p == "i":
new_k2 = new_k.replace("linear1", "proj_mlp")
else:
new_k2 = new_k.replace("linear1", f"attn.to_{p}")
chunk_size = 12288 if p == "i" else 3072
new_state_dict[new_k2] = v[start : start + chunk_size]
start += chunk_size
elif "linear2" in k:
new_k = new_k.replace("linear2", "proj_out")
new_k = new_k.replace("modulation_lin", ".norm.linear")
new_state_dict[new_k] = v
else:
assert "modulation.lin" in k
new_k = new_k.replace("modulation.lin", "norm.linear")
new_state_dict[new_k] = v
os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
save_file(new_state_dict, args.output_path)
if __name__ == "__main__":
main()
import os
from typing import Any, Callable, Optional, Union
import torch
import torchvision.transforms.functional as F
from PIL import Image
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
from einops import rearrange
from huggingface_hub import hf_hub_download
from peft.tuners import lora
from nunchaku.models.flux import inject_pipeline, load_quantized_model
from nunchaku.pipelines.flux import quantize_t5
class FluxPix2pixTurboPipeline(FluxPipeline):
def update_alpha(self, alpha: float) -> None:
self._alpha = alpha
transformer = self.transformer
for n, p in transformer.named_parameters():
if n in self._tuned_state_dict:
new_data = self._tuned_state_dict[n] * alpha + self._original_state_dict[n] * (1 - alpha)
new_data = new_data.to(self._execution_device).to(p.dtype)
p.data.copy_(new_data)
if self.precision == "bf16":
for m in transformer.modules():
if isinstance(m, lora.LoraLayer):
m.scaling["default_0"] = alpha
else:
assert self.precision == "int4"
transformer.nunchaku_set_lora_scale(alpha)
def load_control_module(
self,
pretrained_model_name_or_path: str,
weight_name: str | None = None,
svdq_lora_path: str | None = None,
alpha: float = 1,
):
state_dict, alphas = self.lora_state_dict(
pretrained_model_name_or_path, weight_name=weight_name, return_alphas=True
)
transformer = self.transformer
original_state_dict = {}
tuned_state_dict = {}
assert isinstance(transformer, FluxTransformer2DModel)
for n, p in transformer.named_parameters():
if f"transformer.{n}" in state_dict:
original_state_dict[n] = p.data.cpu()
tuned_state_dict[n] = state_dict[f"transformer.{n}"].cpu()
self._original_state_dict = original_state_dict
self._tuned_state_dict = tuned_state_dict
if self.precision == "bf16":
self.load_lora_into_transformer(state_dict, {}, transformer=transformer)
else:
assert svdq_lora_path is not None
self.transformer.nunchaku_update_params(svdq_lora_path)
self.update_alpha(alpha)
@torch.no_grad()
def __call__(
self,
image: str | Image.Image,
image_type: str = "sketch",
alpha: float = 1.0,
prompt: str | None = None,
prompt_2: str | None = None,
height: int | None = 1024,
width: int | None = 1024,
timesteps: list[int] | None = None,
generator: torch.Generator | None = None,
prompt_embeds: torch.FloatTensor | None = None,
pooled_prompt_embeds: torch.FloatTensor | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
joint_attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
):
if alpha != self._alpha:
self.update_alpha(alpha)
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
guidance_scale = 0
num_images_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
batch_size = 1
device = self._execution_device
lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if isinstance(image, str):
image = Image.open(image).convert("RGB").resize((width, height), Image.LANCZOS)
else:
image = image.resize((width, height), Image.LANCZOS)
image_t = F.to_tensor(image) < 0.5
image_t = image_t.unsqueeze(0).to(self.dtype).to(device)
image_t = (image_t - 0.5) * 2
# 4. Prepare latent variables
encoded_image = self.vae.encode(image_t, return_dict=False)[0].sample(generator=generator)
encoded_image = (encoded_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
if generator is None:
z = torch.randn_like(encoded_image)
else:
z = torch.randn(
encoded_image.shape, device=generator.device, dtype=encoded_image.dtype, generator=generator
).to(device)
noisy_latent = z * (1 - alpha) + encoded_image * alpha
noisy_latent = rearrange(noisy_latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
num_channels_latents = self.transformer.config.in_channels // 4
_, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents=None,
)
# 5. Denoising
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
t = torch.full((batch_size,), 1.0, dtype=self.dtype, device=device)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=device)
guidance = guidance.expand(noisy_latent.shape[0])
else:
guidance = None
pred = self.transformer(
hidden_states=noisy_latent,
timestep=t,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
encoded_output = noisy_latent - pred
if output_type == "latent":
image = encoded_output
else:
encoded_output = self._unpack_latents(encoded_output, height, width, self.vae_scale_factor)
encoded_output = (encoded_output / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(encoded_output, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
qmodel_device = kwargs.pop("qmodel_device", "cuda:0")
qmodel_device = torch.device(qmodel_device)
qmodel_path = kwargs.pop("qmodel_path", None)
qencoder_path = kwargs.pop("qencoder_path", None)
pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
pipeline.precision = "bf16"
if qmodel_path is not None:
assert isinstance(qmodel_path, str)
if not os.path.exists(qmodel_path):
hf_repo_id = os.path.dirname(qmodel_path)
filename = os.path.basename(qmodel_path)
qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
inject_pipeline(pipeline, m)
pipeline.precision = "int4"
if qencoder_path is not None:
assert isinstance(qencoder_path, str)
if not os.path.exists(qencoder_path):
hf_repo_id = os.path.dirname(qencoder_path)
filename = os.path.basename(qencoder_path)
qencoder_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
quantize_t5(pipeline, qencoder_path)
return pipeline
import argparse
import torch
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="Path to the input image")
parser.add_argument("-o", "--output-path", type=str, help="Path to save the output image", default="output.png")
parser.add_argument("-t", "--type", type=str, help="Input type", default="sketch", choices=["sketch", "canny"])
parser.add_argument(
"-m", "--model", type=str, default="pretrained/converted/sketch.safetensors", help="Path to the model"
)
parser.add_argument("-p", "--prompt", type=str, help="Prompt to use for the model", default="a cat")
parser.add_argument("-a", "--alpha", type=float, default=0.4, help="Alpha value for LoRA")
args = parser.parse_args()
return args
def main():
args = get_args()
pipeline = FluxPix2pixTurboPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_control_module(args.model, alpha=args.alpha)
img = pipeline(image=args.input_path, image_type=args.type, alpha=args.alpha, prompt=args.prompt).images[0]
img.save(args.output_path)
if __name__ == "__main__":
main()
# Adapted from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import random
import tempfile
import time
import GPUtil
import gradio as gr
import torch
from PIL import Image
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_args
from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLES, STYLE_NAMES
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
args = get_args()
if args.precision == "bf16":
pipeline = FluxPix2pixTurboPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
pipeline.load_control_module(
"mit-han-lab/svdquant-models", "flux.1-pix2pix-turbo-sketch2image.safetensors", alpha=DEFAULT_SKETCH_GUIDANCE
)
else:
assert args.precision == "int4"
pipeline = FluxPix2pixTurboPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if args.use_qencoder else None,
)
pipeline = pipeline.to("cuda")
pipeline.load_control_module(
"mit-han-lab/svdquant-models",
"flux.1-pix2pix-turbo-sketch2image.safetensors",
svdq_lora_path="mit-han-lab/svdquant-models/svdq-flux.1-pix2pix-turbo-sketch2image.safetensors",
alpha=DEFAULT_SKETCH_GUIDANCE,
)
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
def save_image(img):
if isinstance(img, dict):
img = img["composite"]
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(temp_file.name)
return temp_file.name
def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image.Image, str]:
is_unsafe_prompt = False
if not safety_checker(prompt):
is_unsafe_prompt = True
prompt = "A peaceful world."
prompt = prompt_template.format(prompt=prompt)
start_time = time.time()
result_image = pipeline(
image=image["composite"],
image_type="sketch",
alpha=sketch_guidance,
prompt=prompt,
generator=torch.Generator().manual_seed(int(seed)),
).images[0]
latency = time.time() - start_time
if latency < 1:
latency = latency * 1000
latency_str = f"{latency:.2f}ms"
else:
latency_str = f"{latency:.2f}s"
if is_unsafe_prompt:
latency_str += " (Unsafe prompt detected)"
torch.cuda.empty_cache()
return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
gr.HTML(DESCRIPTION.format(device_info=device_info, notice=notice))
with gr.Row(elem_id="main_row"):
with gr.Column(elem_id="column_input"):
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.Sketchpad(
value=blank_image,
height=640,
image_mode="RGB",
sources=["upload", "clipboard"],
type="pil",
label="Sketch",
show_label=False,
show_download_button=True,
interactive=True,
transforms=[],
canvas_size=(1024, 1024),
scale=1,
brush=gr.Brush(default_size=1, colors=["#000000"], color_mode="fixed"),
format="png",
layers=False,
)
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
with gr.Row():
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
prompt_template = gr.Textbox(
label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
)
with gr.Row():
sketch_guidance = gr.Slider(
label="Sketch Guidance",
show_label=True,
minimum=0,
maximum=1,
value=DEFAULT_SKETCH_GUIDANCE,
step=0.01,
scale=5,
)
with gr.Row():
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
with gr.Group():
result = gr.Image(
format="png",
height=640,
image_mode="RGB",
type="pil",
label="Result",
show_label=False,
show_download_button=True,
interactive=False,
elem_id="output_image",
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
gr.Markdown("**2**. Start sketching")
gr.Markdown("**3**. Change the image style using a style template")
gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider (typically between 0.2 and 0.4)")
gr.Markdown("**5**. Try different seeds to generate different results")
run_inputs = [canvas, prompt, prompt_template, sketch_guidance, seed]
run_outputs = [result, latency_result]
randomize_seed.click(
lambda: random.randint(0, MAX_SEED),
inputs=[],
outputs=seed,
api_name=False,
queue=False,
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
style.change(
lambda x: STYLES[x],
inputs=[style],
outputs=[prompt_template],
api_name=False,
queue=False,
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.on(
triggers=[prompt.submit, run_button.click, canvas.change],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
download_result.click(fn=save_image, inputs=result, outputs=download_result)
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True)
import cv2
import numpy as np
from PIL import Image
import argparse
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
# parser.add_argument(
# "-m", "--model", type=str, default="pretrained/converted/sketch.safetensors", help="Path to the model"
# )
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
args = parser.parse_args()
return args
STYLES = {
"None": "{prompt}",
"Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
"3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
"Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
"Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
"Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
"Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
"Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "3D Model"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
DEFAULT_SKETCH_GUIDANCE = 0.28
# Nunchaku INT4 FLUX.1 Models
## Text-to-Image Gradio Demo
```shell
python run_gradio.py
```
* The demo also defaults to the FLUX.1-schnell model. To switch to the FLUX.1-dev model, use `-m dev`.
* By default, the Gemma-2B model is loaded as a safety checker. To disable this feature and save GPU memory, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, only the INT4 DiT is loaded. Use `-p int4 bf16` to add a BF16 DiT for side-by-side comparison, or `-p bf16` to load only the BF16 model.
## Command Line Inference
We provide a script, [generate.py](generate.py), that generates an image from a text prompt directly from the command line, similar to the demo. Simply run:
```shell
python generate.py --prompt "Your Text Prompt"
```
* The generated image will be saved as `output.png` by default. You can specify a different path using the `-o` or `--output-path` options.
* The script defaults to using the FLUX.1-schnell model. To switch to the FLUX.1-dev model, use `-m dev`.
* By default, the script uses our INT4 model. To use the BF16 model instead, specify `-p bf16`.
* You can specify `--use-qencoder` to use our W4A16 text encoder.
* You can adjust the number of inference steps and guidance scale with `-t` and `-g`, respectively. For the FLUX.1-schnell model, the defaults are 4 steps and a guidance scale of 0; for the FLUX.1-dev model, the defaults are 50 steps and a guidance scale of 3.5.
* When using the FLUX.1-dev model, you also have the option to load a LoRA adapter with `--lora-name`. Available choices are `None`, [`Anime`](https://huggingface.co/alvdansen/sonny-anime-fixed), [`GHIBSKY Illustration`](https://huggingface.co/aleksa-codes/flux-ghibsky-illustration), [`Realism`](https://huggingface.co/XLabs-AI/flux-RealismLora), [`Children Sketch`](https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch), and [`Yarn Art`](https://huggingface.co/linoyts/yarn_art_Flux_LoRA), with the default set to `None`. You can also specify the LoRA weight with `--lora-weight`, which defaults to 1.
## Latency Benchmark
To measure the latency of our INT4 models, use the following command:
```shell
python latency.py
```
* The script defaults to the INT4 FLUX.1-schnell model. To switch to FLUX.1-dev, use the `-m dev` option. For BF16 precision, add `-p bf16`.
* Adjust the number of inference steps and the guidance scale using `-t` and `-g`, respectively.
- For FLUX.1-schnell, the defaults are 4 steps and a guidance scale of 0.
- For FLUX.1-dev, the defaults are 50 steps and a guidance scale of 3.5.
* By default, the script measures the end-to-end latency for generating a single image. To measure the latency of a single DiT forward step instead, use the `--mode step` flag.
* Specify the number of warmup and test runs using `--warmup_times` and `--test_times`. The defaults are 2 warmup runs and 10 test runs.
## Quality Results
Below are the steps to reproduce the quality metrics reported in our paper. First, install a few additional packages for the image quality metrics:
```shell
pip install clean-fid torchmetrics image-reward clip datasets
```
Then generate images using both the original BF16 model and our INT4 model on the [MJHQ](https://huggingface.co/datasets/playgroundai/MJHQ-30K) and [DCI](https://github.com/facebookresearch/DCI) datasets:
```shell
python evaluate.py -p int4
python evaluate.py -p bf16
```
* The commands above will generate images from FLUX.1-schnell on both datasets. Use `-m dev` to switch to FLUX.1-dev, or specify a single dataset with `-d MJHQ` or `-d DCI`.
* By default, generated images are saved to `results/$MODEL/$PRECISION`. Customize the output path using the `-o` option if desired.
* You can also adjust the number of inference steps and the guidance scale using `-t` and `-g`, respectively.
- For FLUX.1-schnell, the defaults are 4 steps and a guidance scale of 0.
- For FLUX.1-dev, the defaults are 50 steps and a guidance scale of 3.5.
* To accelerate the generation process, you can distribute the workload across multiple GPUs. For instance, if you have $N$ GPUs, then on GPU $i$ ($0 \le i < N$) you can add the options `--chunk-start $i --chunk-step $N`. This setup ensures each GPU handles a distinct portion of the workload, enhancing overall efficiency.
Finally, you can compute the metrics for the images with:
```shell
python get_metrics.py results/schnell/int4 results/schnell/bf16
```
Remember to replace the example paths with the actual paths to your image folders.
**Notes:**
- The script will calculate quality metrics (CLIP IQA, CLIP Score, Image Reward, FID) only for the first folder specified. Ensure the INT4 results folder is listed first.
- **Similarity Metrics**: If a second folder path is not provided, similarity metrics (LPIPS, PSNR, SSIM) will be skipped (a minimal sketch of this computation follows these notes).
- **Output File**: Metric results are saved in `metrics.json` by default. Use `-o` to specify a custom output file if needed.
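For reference, the PSNR/SSIM part of the similarity computation could be sketched as follows (an illustrative example using `torchmetrics`, not the provided `get_metrics.py`; it assumes both folders contain identically named images, as produced by the commands above):

```python
import os

from PIL import Image
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchvision.transforms.functional import to_tensor

def folder_similarity(int4_dir: str, bf16_dir: str) -> dict:
    # Accumulate PSNR/SSIM over all image pairs that share a filename.
    psnr = PeakSignalNoiseRatio(data_range=1.0)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
    for name in sorted(os.listdir(int4_dir)):
        a = to_tensor(Image.open(os.path.join(int4_dir, name)).convert("RGB")).unsqueeze(0)
        b = to_tensor(Image.open(os.path.join(bf16_dir, name)).convert("RGB")).unsqueeze(0)
        psnr.update(a, b)
        ssim.update(a, b)
    return {"PSNR": psnr.compute().item(), "SSIM": ssim.compute().item()}

# Example: folder_similarity("results/schnell/int4", "results/schnell/bf16")
```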
h1{text-align:center}
h2{text-align:center}
#random_seed {height: 72px;}
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://i.ibb.co/9gh96K6/logo.png"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
FLUX.1-{model} Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://www.cs.cmu.edu/~srinivas/'>Chenlin Meng</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
[Paper]
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant'>
[Website]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant'>
[Blog]
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>
&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
</div>
</div>
.gradio-container{max-width: 560px !important}
.gradio-container{max-width: 1200px !important}
import os
import random
import datasets
import yaml
from PIL import Image
_CITATION = """\
@InProceedings{Urbanek_2024_CVPR,
author = {Urbanek, Jack and Bordes, Florian and Astolfi, Pietro and Williamson, Mary and Sharma, Vasu and Romero-Soriano, Adriana},
title = {A Picture is Worth More Than 77 Text Tokens: Evaluating CLIP-Style Models on Dense Captions},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2024},
pages = {26700-26709}
}
"""
_DESCRIPTION = """\
The Densely Captioned Images dataset, or DCI, consists of 7805 images from SA-1B,
each with a complete description aiming to capture the full visual detail of what is present in the image.
Much of the description is directly aligned to submasks of the image.
"""
_HOMEPAGE = "https://github.com/facebookresearch/DCI"
_LICENSE = "Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/DCI/blob/main/LICENSE)"
IMAGE_URL = "https://scontent.xx.fbcdn.net/m1/v/t6/An_zz_Te0EtVC_cHtUwnyNKODapWXuNNPeBgZn_3XY8yDFzwHrNb-zwN9mYCbAeWUKQooCI9mVbwvzZDZzDUlscRjYxLKsw.tar?ccb=10-5&oh=00_AYBnKR-fSIir-E49Q7-qO2tjmY0BGJhCciHS__B5QyiBAg&oe=673FFA8A&_nc_sid=0fdd51"
PROMPT_URLS = {"sDCI": "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.yaml"}
class DCIConfig(datasets.BuilderConfig):
def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
super(DCIConfig, self).__init__(
name=kwargs.get("name", "default"),
version=kwargs.get("version", "0.0.0"),
data_dir=kwargs.get("data_dir", None),
data_files=kwargs.get("data_files", None),
description=kwargs.get("description", None),
)
self.max_dataset_size = max_dataset_size
self.return_gt = return_gt
class DCI(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
BUILDER_CONFIG_CLASS = DCIConfig
BUILDER_CONFIGS = [DCIConfig(name="sDCI", version=VERSION, description="sDCI full prompt set")]
DEFAULT_CONFIG_NAME = "sDCI"
def _info(self):
features = datasets.Features(
{
"filename": datasets.Value("string"),
"image": datasets.Image(),
"prompt": datasets.Value("string"),
"meta_path": datasets.Value("string"),
"image_root": datasets.Value("string"),
"image_path": datasets.Value("string"),
"split": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.download.DownloadManager):
image_url = IMAGE_URL
meta_url = PROMPT_URLS[self.config.name]
meta_path = dl_manager.download(meta_url)
image_root = dl_manager.download_and_extract(image_url)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
)
]
def _generate_examples(self, meta_path: str, image_root: str):
meta = yaml.safe_load(open(meta_path, "r"))
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
prompt = meta[name]
image_path = os.path.join(image_root, f"{name}.jpg")
yield i, {
"filename": name,
"image": Image.open(image_path) if self.config.return_gt else None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": image_path,
"split": self.config.name,
}
import json
import os
import random
import datasets
from PIL import Image
_CITATION = """\
@misc{li2024playground,
title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
year={2024},
eprint={2402.17245},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
"""
_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""
_HOMEPAGE = "https://huggingface.co/datasets/playgroundai/MJHQ-30K"
_LICENSE = (
"Playground v2.5 Community License "
"(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md)"
)
IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/mjhq30k_imgs.zip"
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"
class MJHQConfig(datasets.BuilderConfig):
def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
super(MJHQConfig, self).__init__(
name=kwargs.get("name", "default"),
version=kwargs.get("version", "0.0.0"),
data_dir=kwargs.get("data_dir", None),
data_files=kwargs.get("data_files", None),
description=kwargs.get("description", None),
)
self.max_dataset_size = max_dataset_size
self.return_gt = return_gt
class MJHQ(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
BUILDER_CONFIG_CLASS = MJHQConfig
BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")]
DEFAULT_CONFIG_NAME = "MJHQ"
def _info(self):
features = datasets.Features(
{
"filename": datasets.Value("string"),
"category": datasets.Value("string"),
"image": datasets.Image(),
"prompt": datasets.Value("string"),
"prompt_path": datasets.Value("string"),
"image_root": datasets.Value("string"),
"image_path": datasets.Value("string"),
"split": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.download.DownloadManager):
meta_path = dl_manager.download(META_URL)
image_root = dl_manager.download_and_extract(IMAGE_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
),
]
def _generate_examples(self, meta_path: str, image_root: str):
with open(meta_path, "r") as f:
meta = json.load(f)
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
category = meta[name]["category"]
prompt = meta[name]["prompt"]
image_path = os.path.join(image_root, category, f"{name}.jpg")
yield i, {
"filename": name,
"category": category,
"image": Image.open(image_path) if self.config.return_gt else None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": image_path,
"split": self.config.name,
}