Commit fcbab540 authored by limm's avatar limm
Browse files

fix hip compile

parent 7c282e2e
Pipeline #3055 failed with stages
<div align="center" id="nunchaku_logo">
<img src="https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/logo/v2/nunchaku-compact-transparent.png" alt="logo" width="220"></img>
</div>
<h3 align="center">
<a href="http://arxiv.org/abs/2411.05007"><b>Paper</b></a> | <a href="https://nunchaku.tech/docs/nunchaku/"><b>Docs</b></a> | <a href="https://hanlab.mit.edu/projects/svdquant"><b>Website</b></a> | <a href="https://hanlab.mit.edu/blog/svdquant"><b>Blog</b></a> | <a href="https://svdquant.mit.edu"><b>Demo</b></a> | <a href="https://huggingface.co/nunchaku-tech"><b>Hugging Face</b></a> | <a href="https://modelscope.cn/organization/nunchaku-tech"><b>ModelScope</b></a> | <a href="https://github.com/nunchaku-tech/ComfyUI-nunchaku"><b>ComfyUI</b></a>
</h3>
source /opt/dtk/env.sh
<div align="center">
<a href=https://discord.gg/Wk6PnwX9Sm target="_blank"><img src=https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2FWk6PnwX9Sm%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total height=22px></a>
<a href=https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/wechat.jpg target="_blank"><img src=https://img.shields.io/badge/WeChat-07C160?logo=wechat&logoColor=white height=22px></a>
<a href=https://deepwiki.com/nunchaku-tech/nunchaku target="_blank"><img src=https://deepwiki.com/badge.svg height=22px></a>
</div>
source /usr/local/bin/fastpt -T
**Nunchaku** is a high-performance inference engine optimized for 4-bit neural networks, as introduced in our paper [SVDQuant](http://arxiv.org/abs/2411.05007). For the underlying quantization library, check out [DeepCompressor](https://github.com/nunchaku-tech/deepcompressor).
export CPLUS_INCLUDE_PATH=/opt/dtk/roctracer/include:$CPLUS_INCLUDE_PATH
export AMDGPU_TARGETS="gfx906;gfx926;gfx928;gfx936"
Join our user groups on [**Discord**](https://discord.gg/Wk6PnwX9Sm) and [**WeChat**](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/wechat.jpg) to engage in discussions with the community! More details can be found [here](https://github.com/nunchaku-tech/nunchaku/issues/149). If you have any questions, run into issues, or are interested in contributing, don’t hesitate to reach out!
## News
- **[2025-09-25]** 🔥 Release **4-bit [4/8-step lightning Qwen-Image-Edit-2509](https://huggingface.co/lightx2v/Qwen-Image-Lightning)**! Download on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit-2509) or [ModelScope](https://modelscope.cn/models/nunchaku-tech/nunchaku-qwen-image-edit-2509), and try it with our [example script](examples/v1/qwen-image-edit-2509-lightning.py).
- **[2025-09-24]** 🔥 Released **4-bit [Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509)**! Models are available on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit-2509). Try them out with our [example script](examples/v1/qwen-image-edit-2509.py). Lightning models will follow up!
- **[2025-09-09]** 🔥 Released [**4-bit Qwen-Image-Edit**](https://huggingface.co/Qwen/Qwen-Image-Edit) together with the [4/8-step Lightning](https://huggingface.co/lightx2v/Qwen-Image-Lightning) variants! Models are available on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit). Try them out with our [example script](examples/v1/qwen-image-edit.py).
- **[2025-09-04]** 🚀 Official release of **Nunchaku v1.0.0**! Qwen-Image now supports **asynchronous offloading**, reducing VRAM usage to as little as **3 GiB** with no performance loss. Check out the [tutorial](https://nunchaku.tech/docs/nunchaku/usage/qwenimage.html) to get started.
- **[2025-08-27]** 🔥 Release **4-bit [4/8-step lightning Qwen-Image](https://huggingface.co/lightx2v/Qwen-Image-Lightning)**! Download on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image) or [ModelScope](https://modelscope.cn/models/nunchaku-tech/nunchaku-qwen-image), and try it with our [example script](examples/v1/qwen-image-lightning.py).
- **[2025-08-15]** 🔥 Our **4-bit Qwen-Image** models are now live on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image)! Get started with our [example script](examples/v1/qwen-image.py). *ComfyUI, LoRA, and CPU offloading support are coming soon!*
- **[2025-08-15]** 🚀 The **Python backend** is now available! Explore our Pythonic FLUX models [here](nunchaku/models/transformers/transformer_flux_v2.py) and see the modular **4-bit linear layer** [here](nunchaku/models/linear.py).
<details>
<summary>More</summary>
- **[2025-07-31]** 🚀 **[FLUX.1-Krea-dev](https://www.krea.ai/blog/flux-krea-open-source-release) is now supported!** Check out our new [example script](./examples/flux.1-krea-dev.py) to get started.
- **[2025-07-13]** 🚀 The official [**Nunchaku documentation**](https://nunchaku.tech/docs/nunchaku/) is now live! Explore comprehensive guides and resources to help you get started.
- **[2025-06-29]** 🔥 Support **FLUX.1-Kontext**! Try out our [example script](./examples/flux.1-kontext-dev.py) to see it in action! Our demo is available at this [link](https://svdquant.mit.edu/kontext/)!
- **[2025-06-01]** 🚀 **Release v0.3.0!** This update adds support for multiple-batch inference, [**ControlNet-Union-Pro 2.0**](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0), initial integration of [**PuLID**](https://github.com/ToTheBeginning/PuLID), and introduces [**Double FB Cache**](examples/flux.1-dev-double_cache.py). You can now load Nunchaku FLUX models as a single file, and our upgraded [**4-bit T5 encoder**](https://huggingface.co/nunchaku-tech/nunchaku-t5) now matches **FP8 T5** in quality!
- **[2025-04-16]** 🎥 Released tutorial videos in both [**English**](https://youtu.be/YHAVe-oM7U8?si=cM9zaby_aEHiFXk0) and [**Chinese**](https://www.bilibili.com/video/BV1BTocYjEk5/?share_source=copy_web&vd_source=8926212fef622f25cc95380515ac74ee) to assist installation and usage.
- **[2025-04-09]** 📢 Published the [April roadmap](https://github.com/nunchaku-tech/nunchaku/issues/266) and an [FAQ](https://github.com/nunchaku-tech/nunchaku/discussions/262) to help the community get started and stay up to date with Nunchaku’s development.
- **[2025-04-05]** 🚀 **Nunchaku v0.2.0 released!** This release brings [**multi-LoRA**](examples/flux.1-dev-multiple-lora.py) and [**ControlNet**](examples/flux.1-dev-controlnet-union-pro.py) support with even faster performance powered by [**FP16 attention**](#fp16-attention) and [**First-Block Cache**](#first-block-cache). We've also added compatibility for [**20-series GPUs**](examples/flux.1-dev-turing.py) — Nunchaku is now more accessible than ever!
- **[2025-03-07]** 🚀 **Nunchaku v0.1.4 Released!** We've supported [4-bit text encoder and per-layer CPU offloading](#Low-Memory-Inference), reducing FLUX's minimum memory requirement to just **4 GiB** while maintaining a **2–3× speedup**. This update also fixes various issues related to resolution, LoRA, pin memory, and runtime stability. Check out the release notes for full details!
- **[2025-02-20]** 🚀 **Support NVFP4 precision on NVIDIA RTX 5090!** NVFP4 delivers superior image quality compared to INT4, offering **~3× speedup** on the RTX 5090 over BF16. Learn more in our [blog](https://hanlab.mit.edu/blog/svdquant-nvfp4), checkout [`examples`](./examples) for usage and try [our demo](https://svdquant.mit.edu/flux1-schnell/) online!
- **[2025-02-18]** 🔥 [**Customized LoRA conversion**](#Customized-LoRA) and [**model quantization**](#Customized-Model-Quantization) instructions are now available! **[ComfyUI](./comfyui)** workflows now support **customized LoRA**, along with **FLUX.1-Tools**!
- **[2025-02-11]** 🎉 **[SVDQuant](http://arxiv.org/abs/2411.05007) has been selected as a ICLR 2025 Spotlight! FLUX.1-tools Gradio demos are now available!** Check [here](#gradio-demos) for the usage details! Our new [depth-to-image demo](https://svdquant.mit.edu/flux1-depth-dev/) is also online—try it out!
- **[2025-02-04]** **🚀 4-bit [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) is here!** Enjoy a **2-3× speedup** over the original models. Check out the [examples](./examples) for usage. **ComfyUI integration is coming soon!**
- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](examples/sana1.6b_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
- **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
- **[2024-12-08]** Support [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Please check [ComfyUI-nunchaku](https://github.com/nunchaku-tech/ComfyUI-nunchaku) for the usage.
- **[2024-11-07]** 🔥 Our latest **W4A4** Diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is publicly released! Check [**DeepCompressor**](https://github.com/nunchaku-tech/deepcompressor) for the quantization library.
</details>
## Overview
![teaser](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/teaser.jpg)
**Nunchaku** is a high-performance inference engine for low-bit neural networks. It implements **SVDQuant**, a post-training quantization technique for 4-bit weights and activations that well maintains visual fidelity. On 12B FLUX.1-dev, it achieves 3.6× memory reduction compared to the BF16 model. By eliminating CPU offloading, it offers 8.7× speedup over the 16-bit model when on a 16GB laptop 4090 GPU, 3× faster than the NF4 W4A16 baseline. On PixArt-∑, it demonstrates significantly superior visual quality over other W4A4 or even W4A8 baselines. "E2E" means the end-to-end latency including the text encoder and VAE decoder.
**SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models**<br>
[Muyang Li](https://lmxyy.me)\*, [Yujun Lin](https://yujunlin.com)\*, [Zhekai Zhang](https://hanlab.mit.edu/team/zhekai-zhang)\*, [Tianle Cai](https://www.tianle.website/#/), [Xiuyu Li](https://xiuyuli.com), [Junxian Guo](https://github.com/JerryGJX), [Enze Xie](https://xieenze.github.io), [Chenlin Meng](https://cs.stanford.edu/~chenlin/), [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/), and [Song Han](https://hanlab.mit.edu/songhan) <br>
*MIT, NVIDIA, CMU, Princeton, UC Berkeley, SJTU, and Pika Labs* <br>
https://github.com/user-attachments/assets/fdd4ab68-6489-4c65-8768-259bd866e8f8
## Method
#### Quantization Method -- SVDQuant
![intuition](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/intuition.gif)Overview of SVDQuant. Stage1: Originally, both the activation $\boldsymbol{X}$ and weights $\boldsymbol{W}$ contain outliers, making 4-bit quantization challenging. Stage 2: We migrate the outliers from activations to weights, resulting in the updated activation $\hat{\boldsymbol{X}}$ and weights $\hat{\boldsymbol{W}}$. While $\hat{\boldsymbol{X}}$ becomes easier to quantize, $\hat{\boldsymbol{W}}$ now becomes more difficult. Stage 3: SVDQuant further decomposes $\hat{\boldsymbol{W}}$ into a low-rank component $\boldsymbol{L}_1\boldsymbol{L}_2$ and a residual $\hat{\boldsymbol{W}}-\boldsymbol{L}_1\boldsymbol{L}_2$ with SVD. Thus, the quantization difficulty is alleviated by the low-rank branch, which runs at 16-bit precision.
#### Nunchaku Engine Design
![engine](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/engine.jpg) (a) Naïvely running low-rank branch with rank 32 will introduce 57% latency overhead due to extra read of 16-bit inputs in *Down Projection* and extra write of 16-bit outputs in *Up Projection*. Nunchaku optimizes this overhead with kernel fusion. (b) *Down Projection* and *Quantize* kernels use the same input, while *Up Projection* and *4-Bit Compute* kernels share the same output. To reduce data movement overhead, we fuse the first two and the latter two kernels together.
## Performance
![efficiency](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/efficiency.jpg)SVDQuant reduces the 12B FLUX.1 model size by 3.6× and cuts the 16-bit model's memory usage by 3.5×. With Nunchaku, our INT4 model runs 3.0× faster than the NF4 W4A16 baseline on both desktop and laptop NVIDIA RTX 4090 GPUs. Notably, on the laptop 4090, it achieves a total 10.1× speedup by eliminating CPU offloading. Our NVFP4 model is also 3.1× faster than both BF16 and NF4 on the RTX 5090 GPU.
## Getting Started
- [Installation Guide](https://nunchaku.tech/docs/nunchaku/installation/installation.html)
- [Usage Tutorial](https://nunchaku.tech/docs/nunchaku/usage/basic_usage.html)
- [ComfyUI Plugin: ComfyUI-nunchaku](https://github.com/nunchaku-tech/ComfyUI-nunchaku)
- [Custom Model Quantization: DeepCompressor](https://github.com/nunchaku-tech/deepcompressor)
- [Gradio Demo Apps](https://github.com/nunchaku-tech/nunchaku/tree/main/app)
- [Reproduce SVDQuant Paper Results](app/flux.1/t2i)
- [API Reference](https://nunchaku.tech/docs/nunchaku/python_api/nunchaku.html)
- [Contribution Guide](https://nunchaku.tech/docs/nunchaku/developer/contribution_guide.html)
- [Frequently Asked Questions](https://nunchaku.tech/docs/nunchaku/faq/faq.html)
## Contact Us
For enterprises interested in adopting SVDQuant or Nunchaku, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at muyangli@nunchaku.tech.
## Related Projects
- [Efficient Spatially Sparse Inference for Conditional GANs and Diffusion Models](https://arxiv.org/abs/2211.02048), NeurIPS 2022 & T-PAMI 2023
- [SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://arxiv.org/abs/2211.10438), ICML 2023
- [Q-Diffusion: Quantizing Diffusion Models](https://arxiv.org/abs/2302.04304), ICCV 2023
- [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978), MLSys 2024
- [DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models](https://arxiv.org/abs/2402.19481), CVPR 2024
- [QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving](https://arxiv.org/abs/2405.04532), MLSys 2025
- [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://arxiv.org/abs/2410.10629), ICLR 2025
- [Sparse VideoGen: Accelerating Video Diffusion Transformers with Spatial-Temporal Sparsity](https://arxiv.org/abs/2502.01776), ICML 2025
- [Radial Attention: $O(n \log n)$ Sparse Attention with Energy Decay for Long Video Generation](https://github.com/mit-han-lab/radial-attention), NeurIPS 2025
- [Sparse VideoGen2: Accelerate Video Generation with Sparse Attention via Semantic-Aware Permutation](https://arxiv.org/abs/2505.18875), NeurIPS 2025
## Citation
If you find `nunchaku` useful or relevant to your research, please cite our paper:
```bibtex
@inproceedings{
li2024svdquant,
title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025}
}
```
## Acknowledgments
We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Program, National Science Foundation, Packard Foundation, Dell, LG, Hyundai, and Samsung for supporting this research. We thank NVIDIA for donating the DGX server. We thank [First Intelligence](https://www.first-intelligence.com/) and [Yotta Labs](https://www.yottalabs.ai/) for generously sponsoring our computing resources.
We use [img2img-turbo](https://github.com/GaParmar/img2img-turbo) to train the sketch-to-image LoRA. Our text-to-image and image-to-image UI is built upon [playground-v.25](https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py) and [img2img-turbo](https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py), respectively. Our safety checker is borrowed from [hart](https://github.com/mit-han-lab/hart).
Nunchaku is also inspired by many open-source libraries, including (but not limited to) [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm), [QServe](https://github.com/mit-han-lab/qserve), [AWQ](https://github.com/mit-han-lab/llm-awq), [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), and [Atom](https://github.com/efeslab/Atom).
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=nunchaku-tech/nunchaku&type=Date)](https://www.star-history.com/#nunchaku-tech/nunchaku&Date)
CXX=hipcc CC=hipcc python setup.py bdist_wheel
<div align="center" id="nunchaku_logo">
<img src="https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/logo/v2/nunchaku-compact-transparent.png" alt="logo" width="220"></img>
</div>
<h3 align="center">
<a href="http://arxiv.org/abs/2411.05007"><b>Paper</b></a> | <a href="https://nunchaku.tech/docs/nunchaku/"><b>Docs</b></a> | <a href="https://hanlab.mit.edu/projects/svdquant"><b>Website</b></a> | <a href="https://hanlab.mit.edu/blog/svdquant"><b>Blog</b></a> | <a href="https://svdquant.mit.edu"><b>Demo</b></a> | <a href="https://huggingface.co/nunchaku-tech"><b>Hugging Face</b></a> | <a href="https://modelscope.cn/organization/nunchaku-tech"><b>ModelScope</b></a> | <a href="https://github.com/nunchaku-tech/ComfyUI-nunchaku"><b>ComfyUI</b></a>
</h3>
<div align="center">
<a href=https://discord.gg/Wk6PnwX9Sm target="_blank"><img src=https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2FWk6PnwX9Sm%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total height=22px></a>
<a href=https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/wechat.jpg target="_blank"><img src=https://img.shields.io/badge/WeChat-07C160?logo=wechat&logoColor=white height=22px></a>
<a href=https://deepwiki.com/nunchaku-tech/nunchaku target="_blank"><img src=https://deepwiki.com/badge.svg height=22px></a>
</div>
**Nunchaku** is a high-performance inference engine optimized for 4-bit neural networks, as introduced in our paper [SVDQuant](http://arxiv.org/abs/2411.05007). For the underlying quantization library, check out [DeepCompressor](https://github.com/nunchaku-tech/deepcompressor).
Join our user groups on [**Discord**](https://discord.gg/Wk6PnwX9Sm) and [**WeChat**](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/wechat.jpg) to engage in discussions with the community! More details can be found [here](https://github.com/nunchaku-tech/nunchaku/issues/149). If you have any questions, run into issues, or are interested in contributing, don’t hesitate to reach out!
## News
- **[2025-09-25]** 🔥 Release **4-bit [4/8-step lightning Qwen-Image-Edit-2509](https://huggingface.co/lightx2v/Qwen-Image-Lightning)**! Download on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit-2509) or [ModelScope](https://modelscope.cn/models/nunchaku-tech/nunchaku-qwen-image-edit-2509), and try it with our [example script](examples/v1/qwen-image-edit-2509-lightning.py).
- **[2025-09-24]** 🔥 Released **4-bit [Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509)**! Models are available on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit-2509). Try them out with our [example script](examples/v1/qwen-image-edit-2509.py). Lightning models will follow up!
- **[2025-09-09]** 🔥 Released [**4-bit Qwen-Image-Edit**](https://huggingface.co/Qwen/Qwen-Image-Edit) together with the [4/8-step Lightning](https://huggingface.co/lightx2v/Qwen-Image-Lightning) variants! Models are available on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit). Try them out with our [example script](examples/v1/qwen-image-edit.py).
- **[2025-09-04]** 🚀 Official release of **Nunchaku v1.0.0**! Qwen-Image now supports **asynchronous offloading**, reducing VRAM usage to as little as **3 GiB** with no performance loss. Check out the [tutorial](https://nunchaku.tech/docs/nunchaku/usage/qwenimage.html) to get started.
- **[2025-08-27]** 🔥 Release **4-bit [4/8-step lightning Qwen-Image](https://huggingface.co/lightx2v/Qwen-Image-Lightning)**! Download on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image) or [ModelScope](https://modelscope.cn/models/nunchaku-tech/nunchaku-qwen-image), and try it with our [example script](examples/v1/qwen-image-lightning.py).
- **[2025-08-15]** 🔥 Our **4-bit Qwen-Image** models are now live on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image)! Get started with our [example script](examples/v1/qwen-image.py). *ComfyUI, LoRA, and CPU offloading support are coming soon!*
- **[2025-08-15]** 🚀 The **Python backend** is now available! Explore our Pythonic FLUX models [here](nunchaku/models/transformers/transformer_flux_v2.py) and see the modular **4-bit linear layer** [here](nunchaku/models/linear.py).
<details>
<summary>More</summary>
- **[2025-07-31]** 🚀 **[FLUX.1-Krea-dev](https://www.krea.ai/blog/flux-krea-open-source-release) is now supported!** Check out our new [example script](./examples/flux.1-krea-dev.py) to get started.
- **[2025-07-13]** 🚀 The official [**Nunchaku documentation**](https://nunchaku.tech/docs/nunchaku/) is now live! Explore comprehensive guides and resources to help you get started.
- **[2025-06-29]** 🔥 Support **FLUX.1-Kontext**! Try out our [example script](./examples/flux.1-kontext-dev.py) to see it in action! Our demo is available at this [link](https://svdquant.mit.edu/kontext/)!
- **[2025-06-01]** 🚀 **Release v0.3.0!** This update adds support for multiple-batch inference, [**ControlNet-Union-Pro 2.0**](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0), initial integration of [**PuLID**](https://github.com/ToTheBeginning/PuLID), and introduces [**Double FB Cache**](examples/flux.1-dev-double_cache.py). You can now load Nunchaku FLUX models as a single file, and our upgraded [**4-bit T5 encoder**](https://huggingface.co/nunchaku-tech/nunchaku-t5) now matches **FP8 T5** in quality!
- **[2025-04-16]** 🎥 Released tutorial videos in both [**English**](https://youtu.be/YHAVe-oM7U8?si=cM9zaby_aEHiFXk0) and [**Chinese**](https://www.bilibili.com/video/BV1BTocYjEk5/?share_source=copy_web&vd_source=8926212fef622f25cc95380515ac74ee) to assist installation and usage.
- **[2025-04-09]** 📢 Published the [April roadmap](https://github.com/nunchaku-tech/nunchaku/issues/266) and an [FAQ](https://github.com/nunchaku-tech/nunchaku/discussions/262) to help the community get started and stay up to date with Nunchaku’s development.
- **[2025-04-05]** 🚀 **Nunchaku v0.2.0 released!** This release brings [**multi-LoRA**](examples/flux.1-dev-multiple-lora.py) and [**ControlNet**](examples/flux.1-dev-controlnet-union-pro.py) support with even faster performance powered by [**FP16 attention**](#fp16-attention) and [**First-Block Cache**](#first-block-cache). We've also added compatibility for [**20-series GPUs**](examples/flux.1-dev-turing.py) — Nunchaku is now more accessible than ever!
- **[2025-03-07]** 🚀 **Nunchaku v0.1.4 Released!** We've supported [4-bit text encoder and per-layer CPU offloading](#Low-Memory-Inference), reducing FLUX's minimum memory requirement to just **4 GiB** while maintaining a **2–3× speedup**. This update also fixes various issues related to resolution, LoRA, pin memory, and runtime stability. Check out the release notes for full details!
- **[2025-02-20]** 🚀 **Support NVFP4 precision on NVIDIA RTX 5090!** NVFP4 delivers superior image quality compared to INT4, offering **~3× speedup** on the RTX 5090 over BF16. Learn more in our [blog](https://hanlab.mit.edu/blog/svdquant-nvfp4), checkout [`examples`](./examples) for usage and try [our demo](https://svdquant.mit.edu/flux1-schnell/) online!
- **[2025-02-18]** 🔥 [**Customized LoRA conversion**](#Customized-LoRA) and [**model quantization**](#Customized-Model-Quantization) instructions are now available! **[ComfyUI](./comfyui)** workflows now support **customized LoRA**, along with **FLUX.1-Tools**!
- **[2025-02-11]** 🎉 **[SVDQuant](http://arxiv.org/abs/2411.05007) has been selected as a ICLR 2025 Spotlight! FLUX.1-tools Gradio demos are now available!** Check [here](#gradio-demos) for the usage details! Our new [depth-to-image demo](https://svdquant.mit.edu/flux1-depth-dev/) is also online—try it out!
- **[2025-02-04]** **🚀 4-bit [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) is here!** Enjoy a **2-3× speedup** over the original models. Check out the [examples](./examples) for usage. **ComfyUI integration is coming soon!**
- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](examples/sana1.6b_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
- **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
- **[2024-12-08]** Support [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Please check [ComfyUI-nunchaku](https://github.com/nunchaku-tech/ComfyUI-nunchaku) for the usage.
- **[2024-11-07]** 🔥 Our latest **W4A4** Diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is publicly released! Check [**DeepCompressor**](https://github.com/nunchaku-tech/deepcompressor) for the quantization library.
</details>
## Overview
![teaser](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/teaser.jpg)
**Nunchaku** is a high-performance inference engine for low-bit neural networks. It implements **SVDQuant**, a post-training quantization technique for 4-bit weights and activations that well maintains visual fidelity. On 12B FLUX.1-dev, it achieves 3.6× memory reduction compared to the BF16 model. By eliminating CPU offloading, it offers 8.7× speedup over the 16-bit model when on a 16GB laptop 4090 GPU, 3× faster than the NF4 W4A16 baseline. On PixArt-∑, it demonstrates significantly superior visual quality over other W4A4 or even W4A8 baselines. "E2E" means the end-to-end latency including the text encoder and VAE decoder.
**SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models**<br>
[Muyang Li](https://lmxyy.me)\*, [Yujun Lin](https://yujunlin.com)\*, [Zhekai Zhang](https://hanlab.mit.edu/team/zhekai-zhang)\*, [Tianle Cai](https://www.tianle.website/#/), [Xiuyu Li](https://xiuyuli.com), [Junxian Guo](https://github.com/JerryGJX), [Enze Xie](https://xieenze.github.io), [Chenlin Meng](https://cs.stanford.edu/~chenlin/), [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/), and [Song Han](https://hanlab.mit.edu/songhan) <br>
*MIT, NVIDIA, CMU, Princeton, UC Berkeley, SJTU, and Pika Labs* <br>
https://github.com/user-attachments/assets/fdd4ab68-6489-4c65-8768-259bd866e8f8
## Method
#### Quantization Method -- SVDQuant
![intuition](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/intuition.gif)Overview of SVDQuant. Stage1: Originally, both the activation $\boldsymbol{X}$ and weights $\boldsymbol{W}$ contain outliers, making 4-bit quantization challenging. Stage 2: We migrate the outliers from activations to weights, resulting in the updated activation $\hat{\boldsymbol{X}}$ and weights $\hat{\boldsymbol{W}}$. While $\hat{\boldsymbol{X}}$ becomes easier to quantize, $\hat{\boldsymbol{W}}$ now becomes more difficult. Stage 3: SVDQuant further decomposes $\hat{\boldsymbol{W}}$ into a low-rank component $\boldsymbol{L}_1\boldsymbol{L}_2$ and a residual $\hat{\boldsymbol{W}}-\boldsymbol{L}_1\boldsymbol{L}_2$ with SVD. Thus, the quantization difficulty is alleviated by the low-rank branch, which runs at 16-bit precision.
#### Nunchaku Engine Design
![engine](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/engine.jpg) (a) Naïvely running low-rank branch with rank 32 will introduce 57% latency overhead due to extra read of 16-bit inputs in *Down Projection* and extra write of 16-bit outputs in *Up Projection*. Nunchaku optimizes this overhead with kernel fusion. (b) *Down Projection* and *Quantize* kernels use the same input, while *Up Projection* and *4-Bit Compute* kernels share the same output. To reduce data movement overhead, we fuse the first two and the latter two kernels together.
## Performance
![efficiency](https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/efficiency.jpg)SVDQuant reduces the 12B FLUX.1 model size by 3.6× and cuts the 16-bit model's memory usage by 3.5×. With Nunchaku, our INT4 model runs 3.0× faster than the NF4 W4A16 baseline on both desktop and laptop NVIDIA RTX 4090 GPUs. Notably, on the laptop 4090, it achieves a total 10.1× speedup by eliminating CPU offloading. Our NVFP4 model is also 3.1× faster than both BF16 and NF4 on the RTX 5090 GPU.
## Getting Started
- [Installation Guide](https://nunchaku.tech/docs/nunchaku/installation/installation.html)
- [Usage Tutorial](https://nunchaku.tech/docs/nunchaku/usage/basic_usage.html)
- [ComfyUI Plugin: ComfyUI-nunchaku](https://github.com/nunchaku-tech/ComfyUI-nunchaku)
- [Custom Model Quantization: DeepCompressor](https://github.com/nunchaku-tech/deepcompressor)
- [Gradio Demo Apps](https://github.com/nunchaku-tech/nunchaku/tree/main/app)
- [Reproduce SVDQuant Paper Results](app/flux.1/t2i)
- [API Reference](https://nunchaku.tech/docs/nunchaku/python_api/nunchaku.html)
- [Contribution Guide](https://nunchaku.tech/docs/nunchaku/developer/contribution_guide.html)
- [Frequently Asked Questions](https://nunchaku.tech/docs/nunchaku/faq/faq.html)
## Contact Us
For enterprises interested in adopting SVDQuant or Nunchaku, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at muyangli@nunchaku.tech.
## Related Projects
- [Efficient Spatially Sparse Inference for Conditional GANs and Diffusion Models](https://arxiv.org/abs/2211.02048), NeurIPS 2022 & T-PAMI 2023
- [SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://arxiv.org/abs/2211.10438), ICML 2023
- [Q-Diffusion: Quantizing Diffusion Models](https://arxiv.org/abs/2302.04304), ICCV 2023
- [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978), MLSys 2024
- [DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models](https://arxiv.org/abs/2402.19481), CVPR 2024
- [QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving](https://arxiv.org/abs/2405.04532), MLSys 2025
- [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://arxiv.org/abs/2410.10629), ICLR 2025
- [Sparse VideoGen: Accelerating Video Diffusion Transformers with Spatial-Temporal Sparsity](https://arxiv.org/abs/2502.01776), ICML 2025
- [Radial Attention: $O(n \log n)$ Sparse Attention with Energy Decay for Long Video Generation](https://github.com/mit-han-lab/radial-attention), NeurIPS 2025
- [Sparse VideoGen2: Accelerate Video Generation with Sparse Attention via Semantic-Aware Permutation](https://arxiv.org/abs/2505.18875), NeurIPS 2025
## Citation
If you find `nunchaku` useful or relevant to your research, please cite our paper:
```bibtex
@inproceedings{
li2024svdquant,
title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025}
}
```
## Acknowledgments
We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Program, National Science Foundation, Packard Foundation, Dell, LG, Hyundai, and Samsung for supporting this research. We thank NVIDIA for donating the DGX server. We thank [First Intelligence](https://www.first-intelligence.com/) and [Yotta Labs](https://www.yottalabs.ai/) for generously sponsoring our computing resources.
We use [img2img-turbo](https://github.com/GaParmar/img2img-turbo) to train the sketch-to-image LoRA. Our text-to-image and image-to-image UI is built upon [playground-v.25](https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py) and [img2img-turbo](https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py), respectively. Our safety checker is borrowed from [hart](https://github.com/mit-han-lab/hart).
Nunchaku is also inspired by many open-source libraries, including (but not limited to) [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm), [QServe](https://github.com/mit-han-lab/qserve), [AWQ](https://github.com/mit-han-lab/llm-awq), [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), and [Atom](https://github.com/efeslab/Atom).
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=nunchaku-tech/nunchaku&type=Date)](https://www.star-history.com/#nunchaku-tech/nunchaku&Date)
......@@ -26,6 +26,29 @@ class CustomBuildExtension(BuildExtension):
def get_sm_targets() -> list[str]:
is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
if is_rocm:
# ========== ROCm / AMD 路径 ==========
# 手动指定或自动检测 AMD gfx 架构
# 注意:ROCm 不使用 "SM",而是 "gfxXXXX"
gfx_arch = os.getenv("AMDGPU_TARGETS", None)
if gfx_arch is None:
# 尝试从 PyTorch 获取(部分版本支持)
try:
# 示例:'gfx942' for MI300X
props = torch.cuda.get_device_properties(0)
# 在 ROCm 中,name 可能包含 gfx 信息,或需查表
# 这里保守起见,要求用户显式设置
raise NotImplementedError("Auto-detection of AMD arch not reliable. Set NUNCHAKU_AMD_ARCH=gfx942 etc.")
except:
raise RuntimeError(
"Running on ROCm, but NUNCHAKU_AMD_ARCH not set. "
"Please specify your AMD GPU architecture, e.g.: "
"export NUNCHAKU_AMD_ARCH=gfx942 # for MI300X"
)
return [gfx_arch] # 返回如 ["gfx942"]
else:
nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
try:
nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
......@@ -48,12 +71,12 @@ def get_sm_targets() -> list[str]:
sm = f"{capability[0]}{capability[1]}"
if sm == "120" and support_sm120:
sm = "120a"
assert sm in ["75", "80", "86", "89", "120a"], f"Unsupported SM {sm}"
assert sm in ["75", "80", "86", "89", "92", "120a"], f"Unsupported SM {sm}"
if sm not in ret:
ret.append(sm)
else:
assert install_mode == "ALL"
ret = ["75", "80", "86", "89"]
ret = ["75", "80", "86", "89", "92"]
if support_sm120:
ret.append("120a")
return ret
......@@ -96,21 +119,24 @@ if __name__ == "__main__":
else:
return []
sm_targets = get_sm_targets()
print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
#sm_targets = get_sm_targets()
#print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
assert len(sm_targets) > 0, "No SM targets found"
#assert len(sm_targets) > 0, "No SM targets found"
GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus", "/FS"]
GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++2a", "-UNDEBUG", "-Og"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++2a", "/UNDEBUG", "/Zc:__cplusplus", "/FS"]
NVCC_FLAGS = [
"-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1",
"-g",
"-std=c++20",
"-std=c++2a",
"-UNDEBUG",
"-Xcudafe",
"--diag_suppress=20208", # spdlog: 'long double' is treated as 'double' in device code
"-mllvm",
"-nv-ptx-asm-transform=true",
"-finline-asm-ptx",
#"-Xcudafe",
#"--diag_suppress=20208", # spdlog: 'long double' is treated as 'double' in device code
*cond("-G"),
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
......@@ -120,17 +146,17 @@ if __name__ == "__main__":
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
f"--threads={len(sm_targets)}",
"--expt-relaxed-constexpr",
#f"--threads={len(sm_targets)}",
f"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--ptxas-options=--allow-expensive-optimizations=true",
#"--ptxas-options=--allow-expensive-optimizations=true",
]
if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
NVCC_FLAGS.append("--generate-line-info")
for target in sm_targets:
NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
#for target in sm_targets:
# NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus", "-Xcompiler", "/FS", "-Xcompiler", "/bigobj"]
......
......@@ -4,7 +4,8 @@
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"
#include <nvtx3/nvToolsExt.h>
// #include <nvtx3/nvToolsExt.h>
#include <roctx.h>
#include <pybind11/functional.h>
......
......@@ -5,7 +5,10 @@
#include "kernels/awq/gemv_awq.h"
#include "kernels/dwconv.h"
#include <nvtx3/nvToolsExt.h>
#include <hip/hip_fp16.h> // 定义 __half, __bfloat16 等
#include <hip/hip_bfloat16.h> // 显式包含 bfloat16(部分 ROCm 版本需要)
// #include <nvtx3/nvToolsExt.h>
#include <roctx.h>
using namespace nunchaku;
......@@ -140,7 +143,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
*dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
*dst.data_ptr<float>() = float(*src.data_ptr<hip_bfloat16>());
} else if (src.dtype() == Tensor::FP16) {
*dst.data_ptr<float>() = float(*src.data_ptr<half>());
} else if (src.dtype() == Tensor::FP32) {
......
......@@ -5,7 +5,8 @@
#include "flash_api.h"
#include "kernels/misc_kernels.h"
#include <nvtx3/nvToolsExt.h>
// #include <nvtx3/nvToolsExt.h>
#include <roctx.h>
using spdlog::fmt_lib::format;
using namespace nunchaku;
......
......@@ -52,7 +52,7 @@ inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != CUBLAS_STATUS_SUCCESS) {
throw std::runtime_error(spdlog::fmt_lib::format(
"CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line()));
"CUBLAS error: {} (at {}:{})", hipblasStatusToString(retValue), location.file_name(), location.line()));
}
return retValue;
}
......
......@@ -18,8 +18,8 @@ __global__ void silu_and_mul_kernel(scalar_t *__restrict__ out, // [...,
const int64_t token_idx_d = token_idx * int64_t(d);
const int64_t token_idx_2d = token_idx_d * 2;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx_2d + idx]);
const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
const scalar_t x = input[token_idx_2d + idx];
const scalar_t y = input[token_idx_2d + d + idx];
out[token_idx_d + idx] = silu(x) * y;
}
}
......@@ -40,8 +40,8 @@ __global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out,
const float zero = 0.0f;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
const float x = (float)input[token_idx * 2 * d + idx] * scale_gate;
const float y = (float)input[token_idx * 2 * d + d + idx] * scale_up;
float t = silu(x) * y;
tmp[token_idx * d + idx] = t;
t = t > zero ? t : -t;
......@@ -63,8 +63,8 @@ __global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out,
}
} else {
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
const float x = (float)input[token_idx * 2 * d + idx] * scale_gate;
const float y = (float)input[token_idx * 2 * d + d + idx] * scale_up;
out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
}
}
......@@ -80,7 +80,7 @@ __global__ void activation_kernel(scalar_t *__restrict__ out, // [..., d
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
const scalar_t x = input[token_idx * d + idx];
out[token_idx * d + idx] = ACT_FN(x);
}
}
......
......@@ -52,9 +52,9 @@ __device__ __forceinline__ static void warp_reduce(float_t *psum, float (*out_sm
#pragma unroll
for (int i = 0; i < Num; ++i) {
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
fpsum[i] += __shfl_xor(fpsum[i], 16);
fpsum[i] += __shfl_xor(fpsum[i], 8);
fpsum[i] += __shfl_xor(fpsum[i], 1);
}
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
......
......@@ -25,7 +25,7 @@ template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
val += __shfl_xor(val, mask, 32);
return val;
}
......@@ -35,7 +35,7 @@ __inline__ __device__ T warpReduceSumV2(T *val) {
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
val[i] += __shfl_xor(val[i], mask, 32);
}
return (T)(0.0f);
}
......@@ -112,7 +112,7 @@ template<typename T>
__inline__ __device__ T warpReduceMax(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
val = max(val, __shfl_xor(val, mask, 32));
return val;
}
/* Calculate the maximum of all elements in a block */
......
......@@ -20,17 +20,18 @@ __device__ __forceinline__ static void trap_unsupported_arch() {
printf("This kernel is not supported on your GPU\n");
}
__syncthreads();
__nanosleep(1000000);
__trap();
// __nanosleep(1000000);
// __trap();
__builtin_trap();
}
#if defined(ENABLE_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__device__ __forceinline__ static __nv_bfloat162
__hfma2(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c) {
trap_unsupported_arch();
return __nv_bfloat162(0.0f, 0.0f);
}
#endif
// #if defined(ENABLE_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// __device__ __forceinline__ static __nv_bfloat162
// __hfma2(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c) {
// trap_unsupported_arch();
// return __nv_bfloat162(0.0f, 0.0f);
// }
// #endif
template<typename T>
struct num_elems;
......@@ -381,7 +382,8 @@ __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_b
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) {
return __float2bfloat162_rn(val);
// return __float2bfloat162_rn(val);
return __float22bfloat162_rn(make_float2(val, val));
}
template<>
......
......@@ -278,8 +278,8 @@ public:
}
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
maxv.x = fmaxf(maxv.x, __shfl_xor_sync(~0, maxv.x, mask));
maxv.y = fmaxf(maxv.y, __shfl_xor_sync(~0, maxv.y, mask));
maxv.x = fmaxf(maxv.x, __shfl_xor(maxv.x, mask));
maxv.y = fmaxf(maxv.y, __shfl_xor(maxv.y, mask));
}
rowmax[m].x = fmaxf(rowmax[m].x, maxv.x * scale);
rowmax[m].y = fmaxf(rowmax[m].y, maxv.y * scale);
......@@ -320,8 +320,8 @@ public:
}
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
sumv.x += __shfl_xor_sync(~0, sumv.x, mask);
sumv.y += __shfl_xor_sync(~0, sumv.y, mask);
sumv.x += __shfl_xor(sumv.x, mask);
sumv.y += __shfl_xor(sumv.y, mask);
}
rowsum[m] = sumv;
}
......@@ -576,7 +576,8 @@ public:
for (int i = 0; i < SHMEM_TILES; i++) {
store<true>(&Q_shmem[warpId][i][laneId], Q[Q.size() - 1 - i]);
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
int dummy = 0;
......
......@@ -78,7 +78,8 @@ public:
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
// __shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
__shared__ uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128] __attribute__((aligned(128)));
constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
using pack_t = unpack_fpsum::pack_t;
......@@ -87,7 +88,8 @@ public:
constexpr int LANES_PER_HEAD = HEAD_DIM / PACK_SIZE;
pack_t reduce_tmp;
__shared__ alignas(128) pack_t pool[NUM_WARPS];
// __shared__ alignas(128) pack_t pool[NUM_WARPS];
__shared__ pack_t pool[NUM_WARPS] __attribute__((aligned(128)));
// load rmsnorm scales
pack_t rms;
......@@ -134,7 +136,7 @@ public:
}
#pragma unroll
for (int mask = LANES_PER_HEAD / 2; mask > 0; mask /= 2) {
sqrsum += __shfl_xor_sync(~0, sqrsum, mask);
sqrsum += __shfl_xor(sqrsum, mask);
}
sqrsum /= HEAD_DIM;
float coef = cuda_frsqrt(sqrsum + epsilon);
......@@ -329,7 +331,8 @@ public:
__shared__ half_t shmem_rmsnorm[NUM_WARPS][HEAD_DIM];
load_rmsnorm(ptr_rmsnorm_weight, &shmem_rmsnorm[warpId][0]);
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
rotemb_warp rotemb = load_rotemb(ptr_rotemb);
......@@ -356,8 +359,8 @@ public:
}
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
sqrsum[0] += __shfl_xor_sync(~0, sqrsum[0], mask);
sqrsum[1] += __shfl_xor_sync(~0, sqrsum[1], mask);
sqrsum[0] += __shfl_xor(sqrsum[0], mask);
sqrsum[1] += __shfl_xor(sqrsum[1], mask);
}
rmsnorm_coef[head][m][0] = cuda_frsqrt(sqrsum[0] / HEAD_DIM + epsilon);
rmsnorm_coef[head][m][1] = cuda_frsqrt(sqrsum[1] / HEAD_DIM + epsilon);
......@@ -594,7 +597,8 @@ public:
for (int i = laneId; i < sizeof(shmem_vk) / sizeof(float) / NUM_WARPS; i += WARP_SIZE) {
*((&shmem_vk[warpId][0][0]) + i) = 0;
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
for (int tile_v = 0; tile_v < LITELA_V_TILES; tile_v++) {
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
......
......@@ -529,7 +529,8 @@ public:
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[2];
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
#pragma unroll
for (int row = 0; row < 8; row++) {
......@@ -548,7 +549,8 @@ public:
pred);
// }
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
......@@ -558,7 +560,8 @@ public:
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[1];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[3];
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
#pragma unroll
for (int row = 0; row < 8; row++) {
......@@ -578,7 +581,8 @@ public:
pred);
// }
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
}
// if (enableReduce) {
// store<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]), reduce_tmp);
......@@ -631,7 +635,8 @@ public:
}
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
for (int n = 0; n < WARP_N_TILES; n++) {
const int row = laneId % 16;
......@@ -640,7 +645,8 @@ public:
ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
}
}
};
......@@ -674,7 +680,8 @@ public:
operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
// __shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
__shared__ uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128] __attribute__((aligned(128)));
const int m_offset = binfo.bm * BLOCK_M + warpId * WARP_M;
const int n_offset = binfo.bn * BLOCK_N;
......@@ -856,7 +863,7 @@ public:
using typename Base::EpilogueDefault; \
using typename Base::EpilogueNop; \
template<bool USE_BIAS, bool USE_SCALE> \
using EpilogueBias = typename Base::EpilogueBias<USE_BIAS, USE_SCALE>; \
using EpilogueBias = typename Base::template EpilogueBias<USE_BIAS, USE_SCALE>; \
using Base::mma_f16xf16_f32; \
using Base::packed_fp32_to_fp16; \
using Base::packed_fp16_to_fp32; \
......
......@@ -106,15 +106,18 @@ __device__ __forceinline__ static void store(T *addr, T val) {
}
if constexpr (sizeof(T) == 4) {
__stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
// __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
*reinterpret_cast<unsigned int *>(addr) = *reinterpret_cast<unsigned int *>(&val);
return;
}
if constexpr (sizeof(T) == 8) {
__stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
// __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
*reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
return;
}
if constexpr (sizeof(T) == 16) {
__stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
// __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
*reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
return;
}
*addr = val;
......@@ -189,9 +192,9 @@ __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
}
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
: "l"(__cvta_generic_to_shared(ptr)));
// asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
// : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
// : "l"(__cvta_generic_to_shared(ptr))); // limengmeng
}
template<typename T>
......
......@@ -36,8 +36,9 @@ public:
printf("FP4 is not available on this device\n");
}
__syncthreads();
__nanosleep(1000000);
__trap();
// __nanosleep(1000000);
// __trap();
__builtin_trap();
}
static_assert(WARP_N % 32 == 0);
......@@ -121,8 +122,8 @@ public:
for (int mask = 2; mask > 0; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_GROUPS; i++) {
maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor_sync(~0, maxvalue[0][i], mask));
maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor_sync(~0, maxvalue[1][i], mask));
maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor(maxvalue[0][i], mask));
maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor(maxvalue[1][i], mask));
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical maxvalue now
......@@ -170,7 +171,7 @@ public:
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
qpacks[j][i] |= __shfl_xor(qpacks[j][i], mask);
}
}
}
......@@ -464,8 +465,8 @@ public:
}
#pragma unroll
for (int mask = 2; mask > 0; mask /= 2) {
maxvalue[0] = __hmax(maxvalue[0], __shfl_xor_sync(~0, maxvalue[0], mask));
maxvalue[1] = __hmax(maxvalue[1], __shfl_xor_sync(~0, maxvalue[1], mask));
maxvalue[0] = __hmax(maxvalue[0], __shfl_xor(maxvalue[0], mask));
maxvalue[1] = __hmax(maxvalue[1], __shfl_xor(maxvalue[1], mask));
}
maxvalue[0] = __shfl_sync(~0, maxvalue[0], laneId / 4 * 4);
maxvalue[1] = __shfl_sync(~0, maxvalue[1], laneId / 4 * 4);
......@@ -505,7 +506,7 @@ public:
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
qpacks[j][i] |= __shfl_xor(qpacks[j][i], mask);
}
}
}
......@@ -576,7 +577,7 @@ public:
for (int mask = NUM_PACKS_PER_ROW / 2; mask > 0; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __hmax(maxvalue[i], __shfl_xor_sync(~0, maxvalue[i], mask));
maxvalue[i] = __hmax(maxvalue[i], __shfl_xor(maxvalue[i], mask));
}
}
......@@ -611,14 +612,16 @@ public:
}
mat[i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW][laneId % NUM_PACKS_PER_ROW] = qpack;
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
ldmatrix(&mat[row][col], output);
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
}
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
......@@ -634,8 +637,11 @@ public:
const int row = blockIdx.x * WARP_M;
const int col = blockIdx.y * WARP_K;
__shared__ alignas(128) half_t oscale_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
// __shared__ alignas(128) half_t oscale_shmem[WARP_M];
// __shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
__shared__ half_t oscale_shmem[WARP_M] __attribute__((aligned(128)));
__shared__ uint8_t tmp_shmem[INSN_M * INSN_K / 2] __attribute__((aligned(128)));
for (int tileId = 0; tileId < WARP_M_TILES; tileId++) {
packed_act_t tmpout;
......@@ -670,8 +676,10 @@ public:
const int col = blockIdx.x * WARP_N;
const int row = blockIdx.y * WARP_K;
__shared__ alignas(128) half_t oscale_shmem[WARP_N];
__shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
// __shared__ alignas(128) half_t oscale_shmem[WARP_N];
// __shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
__shared__ half_t oscale_shmem[WARP_M] __attribute__((aligned(128)));
__shared__ uint8_t tmp_shmem[INSN_M * INSN_K / 2] __attribute__((aligned(128)));
for (int tileId = 0; tileId < WARP_N_TILES; tileId++) {
packed_wgt_t tmpout;
......@@ -1012,10 +1020,12 @@ public:
}
}
if constexpr (!USE_FP4) {
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
pack_ascales(&oscale_shmem[warpId][0],
&oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
}
}
}
......
......@@ -175,12 +175,12 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
using EpilogueBias = typename GEMM::template EpilogueBias<USE_BIAS, USE_SCALE>;
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code
// on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue =
typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
typename GEMM::template EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>(
{typename EpilogueBias::Arguments{
.bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
......@@ -203,7 +203,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
if (rank_up == 0) {
assert(rank_down == 0);
return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>(
return launch_bias.template operator()<typename GEMM::template EpilogueCombination<MidEpilogue, NextEpilogue>>(
{midArgs, nextArgs});
}
......@@ -225,7 +225,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
}
if (rank_down == 0) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
using Epilogue = typename GEMM::template EpilogueCombination<typename LoraUp::EpilogueLoraUp,
MidEpilogue,
NextEpilogue,
typename GEMM::EpilogueNop>;
......@@ -254,7 +254,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
using Epilogue = typename GEMM::template EpilogueCombination<typename LoraUp::EpilogueLoraUp,
MidEpilogue,
typename LoraDown::EpilogueLoraDown,
NextEpilogue,
......@@ -286,7 +286,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
static constexpr float SHIFT_GELU = 0.171875f;
constexpr bool USE_UNSIGNED = !USE_FP4;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
using EpilogueQuantize = typename GEMM::template EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
auto argsQuantize =
typename EpilogueQuantize::Arguments{.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
......@@ -296,7 +296,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template
operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>,
operator()<typename GEMM::template EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>,
typename Epilogues::EpilogueGelu>({typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
......@@ -376,7 +376,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
if (out_q.valid()) {
launch_lora.template
operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>,
operator()<typename GEMM::template EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>,
typename GEMM::EpilogueNop>(
{argsRope,
typename Epilogues::EpiloguePackQKV::Arguments{
......@@ -394,7 +394,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
{});
} else {
launch_lora
.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>,
.template operator()<typename GEMM::template EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>,
typename GEMM::EpilogueNop>({argsRope,
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
......@@ -491,7 +491,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
// dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
// using Lora = typename GEMM::Lora<RANK>;
using kernel = typename GEMM::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
using kernel = typename GEMM::template quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
......
......@@ -21,7 +21,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
......@@ -59,7 +59,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
......
......@@ -118,14 +118,16 @@ public:
}
mat[row][col] = qpack;
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
ldmatrix(&mat[row][col], output);
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
}
/**
......@@ -202,7 +204,7 @@ public:
#pragma unroll
for (int mask = 32 / 2; mask > 0; mask /= 2) {
maxvalue2 = __hmax2(maxvalue2, __shfl_xor_sync(~0, maxvalue2, mask));
maxvalue2 = __hmax2(maxvalue2, __shfl_xor(maxvalue2, mask));
}
return __hmax(maxvalue2.x, maxvalue2.y);
......@@ -328,13 +330,15 @@ public:
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[1];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[3];
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
for (int row = 0; row < INSN_M; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
}
__syncwarp();
// __syncwarp();
__builtin_amdgcn_wave_barrier();
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment