Unverified Commit 57e50f8d authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

style: upgrade the linter (#339)

* style: reformated codes

* style: reformated codes
parent b737368d
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from vars import LORA_PATHS, SVDQ_LORA_PATHS
from nunchaku import NunchakuFluxTransformer2dModel
from vars import LORA_PATHS, SVDQ_LORA_PATHS
def hash_str_to_int(s: str) -> int:
......
......@@ -37,4 +37,4 @@ python latency.py
* Adjust the number of inference steps and the guidance scale using `-t` and `-g`, respectively. The defaults are 20 steps and a guidance scale of 5.
* You can also adjust the [PAG guidance](https://arxiv.org/abs/2403.17377) scale with `--pag-scale`. The default is 2.
* By default, the script measures the end-to-end latency for generating a single image. To measure the latency of a single DiT forward step instead, use the `--mode step` flag.
* Specify the number of warmup and test runs using `--warmup-times` and `--test-times`. The defaults are 2 warmup runs and 10 test runs.
\ No newline at end of file
* Specify the number of warmup and test runs using `--warmup-times` and `--test-times`. The defaults are 2 warmup runs and 10 test runs.
......@@ -6,4 +6,4 @@ h2{text-align:center}
#accessibility {
text-align: center; /* Center-aligns the text */
margin: auto; /* Centers the element horizontally */
}
\ No newline at end of file
}
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
<a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo
......@@ -50,4 +50,4 @@
</div>
{count_info}
</div>
</div>
\ No newline at end of file
</div>
......@@ -2,7 +2,6 @@ import argparse
import os
import torch
from utils import get_pipeline
......
......@@ -4,7 +4,6 @@ import time
import torch
from torch import nn
from tqdm import trange
from utils import get_pipeline
......
......@@ -8,13 +8,13 @@ from datetime import datetime
import GPUtil
import spaces
import torch
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import EXAMPLES, MAX_SEED
from nunchaku.models.safety_checker import SafetyChecker
# import gradio last to avoid conflicts with other imports
import gradio as gr
import gradio as gr # noqa: isort: skip
def get_args() -> argparse.Namespace:
......@@ -73,7 +73,7 @@ def generate(
prompt = "A peaceful world."
images, latency_strs = [], []
for i, pipeline in enumerate(pipelines):
progress = gr.Progress(track_tqdm=True)
gr.Progress(track_tqdm=True)
start_time = time.time()
image = pipeline(
prompt=prompt,
......@@ -124,11 +124,11 @@ if len(gpus) > 0:
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
with gr.Blocks(
css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
title=f"SVDQuant SANA-1600M Demo",
title="SVDQuant SANA-1600M Demo",
) as demo:
def get_header_str():
......
......@@ -46,4 +46,4 @@
</g>
<path d="M418.39,22.56c-.9-2.12-3.08-3.99-2.86-6.3.6-6.24-1.96-9.26-5.87-10.8-5.59-2.76-10.79-2.48-15.59.89-5.16,3.63-6.9,8.92-5.88,15.06-3.44,1.79-6.77,3.46-10.03,5.27-1.04.58-1.67.45-2.57-.24-4.36-3.31-9.77-3.35-14.45-.38-2.92,1.85-5.92,3.61-8.99,5.2-4.67,2.41-8.51,5.37-9.23,11.06-.06.44-.81,1.01-1.34,1.15-2.64.72-5.32,1.29-7.97,1.98-1.09.28-1.8-.03-2.5-.87-3.33-4.01-7.59-5.28-12.62-4.14-3.55.8-7.1,1.63-10.65,2.41-4.53.99-8.9,2.23-11.5,6.61-.14.23-.76.32-1.12.26-3.14-.54-6.26-1.14-9.44-1.73-.4-4.66-2.91-7.77-6.66-10.13-3.81-2.39-7.54-4.92-11.29-7.41-2.5-1.65-5.47-2.9-8.14-1.91-3.92,1.46-5.66-.68-7.62-3.11-.53-.65-1.1-1.28-1.71-1.87-.91-.89-1.15-1.7-.63-3.04,2.56-6.58-1.25-14.13-8-16.06-4.78-1.36-9.57-2.67-14.37-3.94-6.58-1.74-12.14.91-14.99,7.05-.24.51-.79,1.18-1.25,1.23-1.63.18-3.26.33-4.89.46.01.52.01,1.04.01,1.56,4.44-1,8.77-1.17,13.19-.6-1.82,1.27-8.29,2.27-13.22,2.36-.04,1.47-.13,2.95-.23,4.43,4.6-.4,9.19-.79,13.79-1.19.01.08.02.15.03.23-2.2.7-4.39,1.39-6.62,2.09,1.3,2.68,3.69,4.83,6.67,5.69,5.33,1.55,10.69,3.06,16.09,4.37,1.72.42,3.61.13,5.84.18-1.34-2.39-2.39-4.26-3.44-6.13l.3-.23c5.72,6.3,11.43,12.61,17.15,18.91-.06.07-.12.13-.18.2-2.04-1.41-4.09-2.82-6.2-4.27-1.71,5.48.04,10.66,4.66,13.84,4.3,2.96,8.67,5.81,13.05,8.64,5.02,3.25,12.27,1.96,15.19-2.14-2.16-.92-4.3-1.83-6.44-2.74.05-.15.11-.3.16-.45,6.02,1.12,12.04,2.21,18.04,3.4.43.09.91.85,1.05,1.39,1.65,6.24,7.78,10.23,14.06,8.93,4.97-1.03,9.89-2.3,14.84-3.41,4.98-1.12,8.06-4.16,9.57-9.25-2.61.09-5,.18-7.4.27l-.02-.24,27-6.51c.05.15.09.31.14.46l-6.85,3.18c3.69,3.77,9.13,4.98,13.57,2.64,5.32-2.8,10.5-5.87,15.62-9.01,2.83-1.74,5.21-6.46,4.49-8.99-2.38.52-4.76,1.04-7.15,1.57-.01-.08-.03-.16-.04-.24l24.55-13.02.16.19c-1.43,1.36-2.86,2.72-4.35,4.14,4.09,3.31,8.57,4.15,13.26,2.79,5.85-1.7,9.32-5.87,10.62-12.29.39.9.81,1.74,1.2,2.55ZM240.66,6.17c2.19-1.05,6.89,2.57,6.7,5.28-2.92-.11-5.18-1.48-7-3.61-.24-.3-.01-1.52.3-1.67ZM236.31,14.54c-1.54,1.54-1.21,3.32.9,6.16-5.49-1.54-10.72-3-15.95-4.46.03-.17.07-.35.1-.52,2.43-.24,5.06-.28,5.67-3.36.39-1.94-.51-3.39-2.17-4.55,2.51.68,5.01,1.35,7.52,2.03,2.26.62,4.57,1.13,6.77,1.94,1.26.46,2.34,1.39,3.48,1.83-1.1-.18-2.23-.61-3.28-.46-1.08.15-2.29.64-3.04,1.39ZM243.02,19.76c3.02.35,11.2,8.77,12.25,12.7-4.84-3.4-8.69-7.74-12.25-12.7ZM271.35,48.21c-.99,2.02-.01,3.61,1.22,5.22-5.37-3.34-10.84-6.47-15.54-10.72.94.54,1.85,1.43,2.84,1.53,1.04.11,2.39-.23,3.21-.87,1.98-1.55,1.71-3.13-.61-7.24,4.91,3.25,9.83,6.5,14.74,9.76-2.44-.05-4.65-.17-5.86,2.32ZM267.38,32.23c4.46,2.84,9.48,4.89,13.41,9.32-2.49.4-12.99-7.11-13.41-9.32ZM284.99,50.83c3.61-1.39,15.07.42,17.7,2.77-5.94.19-11.65-.91-17.7-2.77ZM322.43,48.01c-2.55,1.22-3.64,2.83-3.16,4.68.58,2.26,2.21,3.21,5.16,3.2-6.25,1.93-12.54,3.69-19.16,4.1,2.4-.49,4.56-1.22,4.65-4.09.1-2.89-1.86-4.04-4.44-4.56,5.59-1.28,11.18-2.56,16.76-3.83.06.16.13.33.19.5ZM315.23,43.15c2.4-2.34,6.44-2.95,8.44-1.33-1.16,2.42-6.21,3.29-8.44,1.33ZM333.09,48.29c5.19-3.09,10.81-4.61,16.85-4.57-5.26,2.89-10.96,4.09-16.85,4.57ZM371.58,39.47l-15.81,9.08c-.12-.12-.24-.24-.36-.36,2.07-1.36,3.17-3.17,2.04-5.48-1.15-2.36-3.34-2.39-5.68-1.99,5.35-3.33,10.55-6.82,16.39-9.16-1.98,1.91-2.68,3.81-1.86,5.56.82,1.73,2.46,2.39,5.28,2.35ZM370.85,27.31c-2,.5-4.03.9-6.07,1.18-.43.06-1.37-.52-1.35-.76.03-.55.45-1.12.83-1.59.23-.28.67-.38,1.02-.57v-.42c1.79,0,3.58-.04,5.36.07.42.02.8.55,1.2.84-.33.43-.58,1.15-.99,1.25ZM378.71,29.44c4.29-4.26,9.38-7.12,15.26-8.59-4.37,4.11-9.65,6.64-15.26,8.59ZM391.92,14.77c-.33.39-1.13.37-1.71.54-.13-.58-.44-1.19-.34-1.73.4-2.33,2.42-4.9,4.89-6.03.17,0,.77.02,1.38.03-.03.62.17,1.4-.12,1.83-1.28,1.85-2.65,3.64-4.1,5.36ZM407.84,23.73c-1.86,1.82-5.89,3.26-8.87,1.19.94-1.27,2.06-2.44,2.73-3.83.31-.64-.06-1.82-.47-2.57-1.06-1.94-3.17-2.19-6.12-.83.01-3.35,2.27-5.98,5.73-6.88,3.25-.84,6.83.81,8.56,3.94,1.53,2.76.85,6.6-1.56,8.98Z"/>
<circle class="cls-1" cx="206.14" cy="15.03" r="8.22"/>
</svg>
\ No newline at end of file
</svg>
......@@ -23,4 +23,4 @@
<polygon class="cls-1" points="538.46 21.84 538.46 39.47 503.18 0 484.49 0 484.49 63.62 502.88 63.62 502.88 24.15 538.26 63.62 556.85 63.62 556.85 21.84 538.46 21.84"/>
<rect class="cls-2" x="538.46" width="18.39" height="14.25"/>
<path class="cls-1" d="M565.55,14.12V0h67.25v14.12h-23.48v49.5h-18.39V14.12h-25.38Z"/>
</svg>
\ No newline at end of file
</svg>
......@@ -40,4 +40,4 @@ For detailed guidance on testing, refer to the [`tests/README.md`](../tests/READ
## Acknowledgments
This contribution guide is adapted from [SGLang](https://docs.sglang.ai/references/contribution_guide.html). We thank them for the inspiration.
\ No newline at end of file
This contribution guide is adapted from [SGLang](https://docs.sglang.ai/references/contribution_guide.html). We thank them for the inspiration.
......@@ -62,17 +62,17 @@ Then verify the Python version and installed PyTorch version:
Install PyTorch appropriate for your setup
- **For most users**:
```bash
"G:\ComfyuI\python\python.exe" -m pip install torch==2.6 torchvision==0.21 torchaudio==2.6
```
- **For RTX 50-series GPUs** (requires PyTorch ≥2.7 with CUDA 12.8):
```bash
"G:\ComfyuI\python\python.exe" -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
```
## Step 3: Install Nunchaku
......@@ -108,55 +108,55 @@ You can also run a test (requires a Hugging Face token for downloading the model
Please use CMD instead of PowerShell for building.
- Step 1: Install Build Tools
```bash
C:\Users\muyang\miniconda3\envs\comfyui\python.exe
"G:\ComfyuI\python\python.exe" -m pip install ninja setuptools wheel build
```
- Step 2: Clone the Repository
```bash
git clone https://github.com/mit-han-lab/nunchaku.git
cd nunchaku
git submodule init
git submodule update
```
- Step 3: Set Up Visual Studio Environment
Locate the `VsDevCmd.bat` script on your system. Example path:
```
C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat
```
Then run:
```bash
"C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
set DISTUTILS_USE_SDK=1
```
- Step 4: Build Nunchaku
```bash
"G:\ComfyuI\python\python.exe" setup.py develop
```
Verify with:
```bash
"G:\ComfyuI\python\python.exe" -c "import nunchaku"
```
You can also run a test (requires a Hugging Face token for downloading the models):
```bash
"G:\ComfyuI\python\python.exe" -m huggingface-cli login
"G:\ComfyuI\python\python.exe" -m nunchaku.test
```
- (Optional) Step 5: Building wheel for Portable Python
If building directly with portable Python fails, you can first build the wheel in a working Conda environment, then install the `.whl` file using your portable Python:
......@@ -182,42 +182,42 @@ Alternatively, install using [ComfyUI-Manager](https://github.com/Comfy-Org/Comf
## 2. Download Models
- **Standard FLUX.1-dev Models**
Start by downloading the standard [FLUX.1-dev text encoders](https://huggingface.co/comfyanonymous/flux_text_encoders/tree/main) and [VAE](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors). You can also optionally download the original [BF16 FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors) model. An example command:
```bash
huggingface-cli download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models/text_encoders
huggingface-cli download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models/text_encoders
huggingface-cli download black-forest-labs/FLUX.1-schnell ae.safetensors --local-dir models/vae
huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/diffusion_models
```
- **SVDQuant 4-bit FLUX.1-dev Models**
Next, download the SVDQuant 4-bit models:
- For **50-series GPUs**, use the [FP4 model](https://huggingface.co/mit-han-lab/svdq-fp4-flux.1-dev).
- For **other GPUs**, use the [INT4 model](https://huggingface.co/mit-han-lab/svdq-int4-flux.1-dev).
Make sure to place the **entire downloaded folder** into `models/diffusion_models`. For example:
```bash
huggingface-cli download mit-han-lab/svdq-int4-flux.1-dev --local-dir models/diffusion_models/svdq-int4-flux.1-dev
```
- **(Optional): Download Sample LoRAs**
You can test with some sample LoRAs like [FLUX.1-Turbo](https://huggingface.co/alimama-creative/FLUX.1-Turbo-Alpha/blob/main/diffusion_pytorch_model.safetensors) and [Ghibsky](https://huggingface.co/aleksa-codes/flux-ghibsky-illustration/blob/main/lora.safetensors). Place these files in the `models/loras` directory:
```bash
huggingface-cli download alimama-creative/FLUX.1-Turbo-Alpha diffusion_pytorch_model.safetensors --local-dir models/loras
huggingface-cli download aleksa-codes/flux-ghibsky-illustration lora.safetensors --local-dir models/loras
```
## 3. Set Up Workflows
To use the official workflows, download them from the [ComfyUI-nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku/tree/main/workflows) and place them in your `ComfyUI/user/default/workflows` directory. The command can be
To use the official workflows, download them from the [ComfyUI-nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku/tree/main/workflows) and place them in your `ComfyUI/user/default/workflows` directory. The command can be
```bash
# From the root of your ComfyUI folder
......@@ -231,4 +231,4 @@ You can now launch ComfyUI and try running the example workflows.
If you encounter issues, refer to our:
- [FAQs](https://github.com/mit-han-lab/nunchaku/discussions/262)
- [GitHub Issues](https://github.com/mit-han-lab/nunchaku/issues)
\ No newline at end of file
- [GitHub Issues](https://github.com/mit-han-lab/nunchaku/issues)
......@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe
from nunchaku.utils import get_gpu_memory, get_precision
base_model = "black-forest-labs/FLUX.1-dev"
......@@ -29,11 +28,6 @@ if need_offload:
else:
pipeline = pipeline.to("cuda")
# apply_cache_on_pipe(
# pipeline, residual_diff_threshold=0.1
# ) # Uncomment this line to enable first-block cache to speedup generation
prompt = "A anime style girl with messy beach waves."
control_image_depth = load_image(
"https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
......
......@@ -7,14 +7,10 @@ from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-dev"
)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
......@@ -24,9 +20,6 @@ apply_cache_on_pipe(
residual_diff_threshold_single=0.12,
)
image = pipeline(
["A cat holding a sign that says hello world"],
num_inference_steps=50
).images[0]
image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
......@@ -20,7 +20,8 @@ public:
ModuleWrapper::init(deviceId);
CUDADeviceContext ctx(this->deviceId);
net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<FluxModel>(
use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
bool isBF16() {
......@@ -28,52 +29,50 @@ public:
return net->dtype == Tensor::BF16;
}
pybind11::function residual_callback;
void set_residual_callback(pybind11::function callback) {
void set_residual_callback(pybind11::function callback) {
pybind11::gil_scoped_acquire gil;
if (!callback || callback.is_none()) {
residual_callback = pybind11::function();
if (net){
if (net) {
net->set_residual_callback(nullptr);
}
return;
}
residual_callback = std::move(callback);
}
residual_callback = std::move(callback);
if (net) {
pybind11::object cb = residual_callback;
net->set_residual_callback([cb](const Tensor &x) -> Tensor {
pybind11::object cb = residual_callback;
net->set_residual_callback([cb](const Tensor &x) -> Tensor {
pybind11::gil_scoped_acquire gil;
torch::Tensor torch_x = to_torch(x);
torch::Tensor torch_x = to_torch(x);
pybind11::object result = cb(torch_x);
torch::Tensor torch_y = result.cast<torch::Tensor>();
Tensor y = from_torch(torch_y);
torch::Tensor torch_y = result.cast<torch::Tensor>();
Tensor y = from_torch(torch_y);
return y;
});
} else {
}
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
torch::Tensor rotary_emb_single,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
bool skip_first_layer = false)
{
torch::Tensor forward(torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
torch::Tensor rotary_emb_single,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
bool skip_first_layer = false) {
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward");
hidden_states = hidden_states.contiguous();
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();
temb = temb.contiguous();
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();
Tensor result = net->forward(
from_torch(hidden_states),
......@@ -83,9 +82,10 @@ public:
from_torch(rotary_emb_context),
from_torch(rotary_emb_single),
controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
skip_first_layer
);
controlnet_single_block_samples.has_value()
? from_torch(controlnet_single_block_samples.value().contiguous())
: Tensor{},
skip_first_layer);
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
......@@ -93,25 +93,24 @@ public:
return output;
}
std::tuple<torch::Tensor, torch::Tensor> forward_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
{
std::tuple<torch::Tensor, torch::Tensor>
forward_layer(int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
hidden_states = hidden_states.contiguous();
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
temb = temb.contiguous();
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
idx,
......@@ -121,35 +120,31 @@ public:
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
);
controlnet_single_block_samples.has_value()
? from_torch(controlnet_single_block_samples.value().contiguous())
: Tensor{});
hidden_states = to_torch(hidden_states_);
hidden_states = to_torch(hidden_states_);
encoder_hidden_states = to_torch(encoder_hidden_states_);
Tensor::synchronizeDevice();
return { hidden_states, encoder_hidden_states };
return {hidden_states, encoder_hidden_states};
}
torch::Tensor forward_single_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_single)
{
torch::Tensor forward_single_layer(int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_single) {
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
hidden_states = hidden_states.contiguous();
temb = temb.contiguous();
hidden_states = hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();
Tensor result = net->single_transformer_blocks.at(idx)->forward(
from_torch(hidden_states),
from_torch(temb),
from_torch(rotary_emb_single)
);
from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));
hidden_states = to_torch(result);
Tensor::synchronizeDevice();
......@@ -159,19 +154,15 @@ public:
// expose the norm1 forward method of the transformer blocks
// this is used by TeaCache to get the norm1 output
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> norm_one_forward(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb
) {
AdaLayerNormZero::Output result = net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
return {
to_torch(result.x),
to_torch(result.gate_msa),
to_torch(result.shift_mlp),
to_torch(result.scale_mlp),
to_torch(result.gate_mlp)
};
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) {
AdaLayerNormZero::Output result =
net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
return {to_torch(result.x),
to_torch(result.gate_msa),
to_torch(result.shift_mlp),
to_torch(result.scale_mlp),
to_torch(result.gate_mlp)};
}
// must be called after loading lora
......@@ -214,5 +205,4 @@ public:
throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
}
}
};
\ No newline at end of file
};
......@@ -16,7 +16,12 @@ public:
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<GEMM_W4A4>((int)in_features,
(int)out_features,
bias,
use_fp4,
bf16 ? Tensor::BF16 : Tensor::FP16,
Device::cuda((int)deviceId));
}
torch::Tensor forward(torch::Tensor x) {
......@@ -53,11 +58,11 @@ public:
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
constexpr int BLOCK_M = 256;
constexpr int WARP_K = 64;
constexpr int NUM_WARPS = 8;
constexpr int BLOCK_M = 256;
constexpr int WARP_K = 64;
constexpr int NUM_WARPS = 8;
constexpr int WARP_M_TILES = 2;
constexpr int WARP_SIZE = 32;
constexpr int WARP_SIZE = 32;
std::stringstream ss;
for (int bm = 0; bm < M / BLOCK_M; bm++) {
......@@ -95,13 +100,10 @@ public:
x = x.contiguous();
auto qout = net->quantize(
from_torch(x),
fuse_glu
);
auto qout = net->quantize(from_torch(x), fuse_glu);
Tensor act = qout.act.copy(Device::cpu());
Tensor ascales = qout.ascales.copy(Device::cpu());
Tensor act = qout.act.copy(Device::cpu());
Tensor ascales = qout.ascales.copy(Device::cpu());
Tensor lora_act = qout.lora_act.copy(Device::cpu());
Tensor::synchronizeDevice();
......@@ -109,5 +111,4 @@ public:
spdlog::debug("act = {}", dumpTensorINT4(act));
spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
}
};
......@@ -10,13 +10,14 @@ class QuantizedGEMM88 : public ModuleWrapper<GEMM_W8A8> {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM88");
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W8A8>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<GEMM_W8A8>(
(int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(torch::Tensor x) {
......@@ -27,10 +28,10 @@ public:
x = x.contiguous();
Tensor result = net->forward(from_torch(x));
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
};
\ No newline at end of file
};
......@@ -18,7 +18,7 @@ public:
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
nunchaku::utils::trim_memory();
Tensor::synchronizeDevice();
}
......@@ -28,7 +28,7 @@ public:
CUDADeviceContext ctx(this->deviceId);
spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
net->loadParams(*provider, partial);
Tensor::synchronizeDevice();
......@@ -41,7 +41,7 @@ public:
CUDADeviceContext ctx(this->deviceId);
spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
net->loadParams(*provider, partial);
Tensor::synchronizeDevice();
......@@ -66,7 +66,7 @@ public:
result[key] = to_torch(value);
}
}
return result;
}
......@@ -82,4 +82,4 @@ protected:
std::unique_ptr<DebugContext> debugContext;
int deviceId = -1;
};
\ No newline at end of file
};
......@@ -7,175 +7,132 @@
namespace nunchaku::ops {
void gemm_w4a4(
std::optional<torch::Tensor> act, // packed act [M, K / 2]
std::optional<torch::Tensor> wgt, // packed act [N, K / 2]
std::optional<torch::Tensor> out, // linear [M, N]
std::optional<torch::Tensor> qout, // packed act [M, N / 2]
std::optional<torch::Tensor> ascales, // packed as [K / 64, M]
std::optional<torch::Tensor> wscales, // packed ws [K / 64, N]
std::optional<torch::Tensor> oscales, // packed as [N / 64, M]
std::optional<torch::Tensor> poolout, // linear [M / PoolSize, N]
std::optional<torch::Tensor> lora_act_in, // packed lora_act [M, R]
std::optional<torch::Tensor> lora_up, // packed lora_wgt [N, R]
std::optional<torch::Tensor> lora_down, // packed lora_wgt [N, R]
std::optional<torch::Tensor> lora_act_out, // packed lora_act [M, R]
std::optional<torch::Tensor> norm_q, // linear [HEAD_DIM]
std::optional<torch::Tensor> norm_k, // linear [HEAD_DIM]
std::optional<torch::Tensor> rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
std::optional<torch::Tensor> bias, // packed ws [N]
std::optional<torch::Tensor> smooth_factor, // packed ws [N], for quantization of the next layer
std::optional<torch::Tensor> out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
std::optional<torch::Tensor> out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales,
bool fuse_silu,
bool fp4,
float alpha,
std::optional<torch::Tensor> wcscales,
std::optional<torch::Tensor> out_q, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
int attn_tokens
) {
spdlog::trace("running gemm_w4a4: ");
void gemm_w4a4(std::optional<torch::Tensor> act, // packed act [M, K / 2]
std::optional<torch::Tensor> wgt, // packed act [N, K / 2]
std::optional<torch::Tensor> out, // linear [M, N]
std::optional<torch::Tensor> qout, // packed act [M, N / 2]
std::optional<torch::Tensor> ascales, // packed as [K / 64, M]
std::optional<torch::Tensor> wscales, // packed ws [K / 64, N]
std::optional<torch::Tensor> oscales, // packed as [N / 64, M]
std::optional<torch::Tensor> poolout, // linear [M / PoolSize, N]
std::optional<torch::Tensor> lora_act_in, // packed lora_act [M, R]
std::optional<torch::Tensor> lora_up, // packed lora_wgt [N, R]
std::optional<torch::Tensor> lora_down, // packed lora_wgt [N, R]
std::optional<torch::Tensor> lora_act_out, // packed lora_act [M, R]
std::optional<torch::Tensor> norm_q, // linear [HEAD_DIM]
std::optional<torch::Tensor> norm_k, // linear [HEAD_DIM]
std::optional<torch::Tensor> rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
std::optional<torch::Tensor> bias, // packed ws [N]
std::optional<torch::Tensor> smooth_factor, // packed ws [N], for quantization of the next layer
std::optional<torch::Tensor> out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales,
bool fuse_silu,
bool fp4,
float alpha,
std::optional<torch::Tensor> wcscales,
std::optional<torch::Tensor> out_q, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
int attn_tokens) {
spdlog::trace("running gemm_w4a4: ");
auto getTensor = [](std::optional<torch::Tensor> &t) {
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
if (ret.valid()) {
spdlog::trace(" {}", ret.shape.str());
} else {
spdlog::trace(" <invalid>");
}
return ret;
};
nunchaku::kernels::gemm_w4a4(
getTensor(act ),
getTensor(wgt ),
getTensor(out ),
getTensor(qout ),
getTensor(ascales ),
getTensor(wscales ),
getTensor(oscales ),
getTensor(poolout ),
getTensor(lora_act_in ),
getTensor(lora_up ),
getTensor(lora_down ),
getTensor(lora_act_out ),
getTensor(norm_q ),
getTensor(norm_k ),
getTensor(rotary_emb ),
getTensor(bias ),
getTensor(smooth_factor),
getTensor(out_vk ),
getTensor(out_linearattn),
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
getTensor(wcscales),
getTensor(out_q),
getTensor(out_k),
getTensor(out_v),
attn_tokens
);
// Tensor::synchronizeDevice();
}
auto getTensor = [](std::optional<torch::Tensor> &t) {
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
if (ret.valid()) {
spdlog::trace(" {}", ret.shape.str());
} else {
spdlog::trace(" <invalid>");
}
return ret;
};
nunchaku::kernels::gemm_w4a4(getTensor(act),
getTensor(wgt),
getTensor(out),
getTensor(qout),
getTensor(ascales),
getTensor(wscales),
getTensor(oscales),
getTensor(poolout),
getTensor(lora_act_in),
getTensor(lora_up),
getTensor(lora_down),
getTensor(lora_act_out),
getTensor(norm_q),
getTensor(norm_k),
getTensor(rotary_emb),
getTensor(bias),
getTensor(smooth_factor),
getTensor(out_vk),
getTensor(out_linearattn),
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
getTensor(wcscales),
getTensor(out_q),
getTensor(out_k),
getTensor(out_v),
attn_tokens);
// Tensor::synchronizeDevice();
}
void attention_fp16(
torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale
) {
nunchaku::kernels::attention_fp16(
from_torch(q),
from_torch(k),
from_torch(v),
from_torch(o),
scale
);
}
void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale) {
nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
}
torch::Tensor gemv_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int64_t m,
int64_t n,
int64_t k,
int64_t group_size)
{
Tensor result = ::gemv_awq(
from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()),
(int)m,
(int)n,
(int)k,
(int)group_size
);
torch::Tensor gemv_awq(torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int64_t m,
int64_t n,
int64_t k,
int64_t group_size) {
Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()),
(int)m,
(int)n,
(int)k,
(int)group_size);
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
return output;
}
return output;
}
torch::Tensor gemm_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros)
{
Tensor result = ::awq_gemm_forward_cuda(
from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous())
);
torch::Tensor
gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) {
Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()));
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
return output;
}
return output;
}
void test_rmsnorm_rope(
torch::Tensor input,
torch::Tensor output,
torch::Tensor norm_q,
torch::Tensor norm_k,
torch::Tensor rotary_emb)
{
nunchaku::kernels::test_rmsnorm_rope(
from_torch(input),
from_torch(output),
from_torch(norm_q),
from_torch(norm_k),
from_torch(rotary_emb)
);
}
void test_rmsnorm_rope(
torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
nunchaku::kernels::test_rmsnorm_rope(
from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
}
void test_pack_qkv(
torch::Tensor input,
torch::Tensor out_q,
torch::Tensor out_k,
torch::Tensor out_v,
int numTokens)
{
nunchaku::kernels::test_pack_qkv(
from_torch(input),
from_torch(out_q),
from_torch(out_k),
from_torch(out_v),
numTokens
);
}
};
\ No newline at end of file
void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
nunchaku::kernels::test_pack_qkv(
from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
}
}; // namespace nunchaku::ops
......@@ -11,49 +11,44 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
.def(py::init<>())
.def("init", &QuantizedFluxModel::init,
py::arg("use_fp4"),
py::arg("offload"),
py::arg("bf16"),
py::arg("deviceId")
)
.def("set_residual_callback", [](QuantizedFluxModel &self, pybind11::object call_back) {
if (call_back.is_none()) {
self.set_residual_callback(pybind11::function());
} else {
self.set_residual_callback(call_back);
}
})
.def("init",
&QuantizedFluxModel::init,
py::arg("use_fp4"),
py::arg("offload"),
py::arg("bf16"),
py::arg("deviceId"))
.def("set_residual_callback",
[](QuantizedFluxModel &self, pybind11::object call_back) {
if (call_back.is_none()) {
self.set_residual_callback(pybind11::function());
} else {
self.set_residual_callback(call_back);
}
})
.def("reset", &QuantizedFluxModel::reset)
.def("load", &QuantizedFluxModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedFluxModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("forward", &QuantizedFluxModel::forward,
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("rotary_emb_single"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none(),
py::arg("skip_first_layer") = false
)
.def("forward_layer", &QuantizedFluxModel::forward_layer,
py::arg("idx"),
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none()
)
.def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false)
.def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false)
.def("forward",
&QuantizedFluxModel::forward,
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("rotary_emb_single"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none(),
py::arg("skip_first_layer") = false)
.def("forward_layer",
&QuantizedFluxModel::forward_layer,
py::arg("idx"),
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none())
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
.def("startDebug", &QuantizedFluxModel::startDebug)
......@@ -61,32 +56,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale)
.def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
.def("isBF16", &QuantizedFluxModel::isBF16)
;
.def("isBF16", &QuantizedFluxModel::isBF16);
py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
.def(py::init<>())
.def("init", &QuantizedSanaModel::init,
py::arg("config"),
py::arg("pag_layers"),
py::arg("use_fp4"),
py::arg("bf16"),
py::arg("deviceId")
)
.def("init",
&QuantizedSanaModel::init,
py::arg("config"),
py::arg("pag_layers"),
py::arg("use_fp4"),
py::arg("bf16"),
py::arg("deviceId"))
.def("reset", &QuantizedSanaModel::reset)
.def("load", &QuantizedSanaModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedSanaModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false)
.def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false)
.def("forward", &QuantizedSanaModel::forward)
.def("forward_layer", &QuantizedSanaModel::forward_layer)
.def("startDebug", &QuantizedSanaModel::startDebug)
.def("stopDebug", &QuantizedSanaModel::stopDebug)
.def("getDebugResults", &QuantizedSanaModel::getDebugResults)
;
.def("getDebugResults", &QuantizedSanaModel::getDebugResults);
py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
.def(py::init<>())
.def("init", &QuantizedGEMM::init)
......@@ -96,8 +83,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("quantize", &QuantizedGEMM::quantize)
.def("startDebug", &QuantizedGEMM::startDebug)
.def("stopDebug", &QuantizedGEMM::stopDebug)
.def("getDebugResults", &QuantizedGEMM::getDebugResults)
;
.def("getDebugResults", &QuantizedGEMM::getDebugResults);
py::class_<Tensor>(m, "Tensor");
py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
.def(py::init<>())
......@@ -107,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("forward", &QuantizedGEMM88::forward)
.def("startDebug", &QuantizedGEMM88::startDebug)
.def("stopDebug", &QuantizedGEMM88::stopDebug)
.def("getDebugResults", &QuantizedGEMM88::getDebugResults)
;
.def("getDebugResults", &QuantizedGEMM88::getDebugResults);
m.def_submodule("ops")
.def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
......@@ -117,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("gemv_awq", nunchaku::ops::gemv_awq)
.def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
.def("test_pack_qkv", nunchaku::ops::test_pack_qkv)
;
.def("test_pack_qkv", nunchaku::ops::test_pack_qkv);
m.def_submodule("utils")
.def("set_log_level", [](const std::string &level) {
spdlog::set_level(spdlog::level::from_str(level));
})
.def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
.def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
.def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
.def("trim_memory", nunchaku::utils::trim_memory)
.def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode)
;
.def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment