Unverified commit 57e50f8d, authored by Muyang Li and committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
import torch
from diffusers import FluxPipeline
from peft.tuners import lora

from nunchaku import NunchakuFluxTransformer2dModel
from vars import LORA_PATHS, SVDQ_LORA_PATHS


def hash_str_to_int(s: str) -> int:
...
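Only the signature of `hash_str_to_int` is visible in this hunk. As a point of reference, a helper with this shape is typically a deterministic string-to-seed hash; a minimal sketch of the pattern (an illustration only, not the repository's actual implementation):

```python
import hashlib

def hash_str_to_int(s: str) -> int:
    # Deterministically hash the string and keep the low 32 bits so the
    # result fits a typical RNG seed range. Illustrative only.
    return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16) % (2**32)
```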
@@ -37,4 +37,4 @@ python latency.py
* Adjust the number of inference steps and the guidance scale using `-t` and `-g`, respectively. The defaults are 20 steps and a guidance scale of 5.
* You can also adjust the [PAG guidance](https://arxiv.org/abs/2403.17377) scale with `--pag-scale`. The default is 2.
* By default, the script measures the end-to-end latency for generating a single image. To measure the latency of a single DiT forward step instead, use the `--mode step` flag.
* Specify the number of warmup and test runs using `--warmup-times` and `--test-times`. The defaults are 2 warmup runs and 10 test runs.
\ No newline at end of file
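To make the warmup/measurement split concrete, this is the pattern those flags correspond to, sketched against a generic diffusers-style `pipeline` object (the names and defaults here mirror the flag documentation above, not the script's internals):

```python
import time
import torch

def measure_latency(pipeline, prompt, steps=20, guidance=5.0, warmup=2, runs=10):
    # Warmup runs are excluded from timing so one-time CUDA/compilation costs
    # do not skew the average.
    for _ in range(warmup):
        pipeline(prompt, num_inference_steps=steps, guidance_scale=guidance)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(runs):
        pipeline(prompt, num_inference_steps=steps, guidance_scale=guidance)
    torch.cuda.synchronize()
    return (time.time() - start) / runs  # average end-to-end seconds per image
```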
@@ -6,4 +6,4 @@ h2{text-align:center}
#accessibility {
  text-align: center; /* Center-aligns the text */
  margin: auto; /* Centers the element horizontally */
}
\ No newline at end of file
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
           alt="logo"
           style="height: 40px; width: auto; display: block; margin: auto;"/>
      <a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo
@@ -50,4 +50,4 @@
    </div>
    {count_info}
  </div>
</div>
\ No newline at end of file
@@ -2,7 +2,6 @@ import argparse
import os
import torch
from utils import get_pipeline
...
@@ -4,7 +4,6 @@ import time
import torch
from torch import nn
from tqdm import trange
from utils import get_pipeline
...
@@ -8,13 +8,13 @@ from datetime import datetime
import GPUtil
import spaces
import torch
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import EXAMPLES, MAX_SEED

# import gradio last to avoid conflicts with other imports
import gradio as gr  # noqa: isort: skip

def get_args() -> argparse.Namespace:
@@ -73,7 +73,7 @@ def generate(
        prompt = "A peaceful world."
    images, latency_strs = [], []
    for i, pipeline in enumerate(pipelines):
        gr.Progress(track_tqdm=True)
        start_time = time.time()
        image = pipeline(
            prompt=prompt,
@@ -124,11 +124,11 @@ if len(gpus) > 0:
    device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
    device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
with gr.Blocks(
    css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
    title="SVDQuant SANA-1600M Demo",
) as demo:
    def get_header_str():
...
@@ -46,4 +46,4 @@
  </g>
<path d="M418.39,22.56c-.9-2.12-3.08-3.99-2.86-6.3.6-6.24-1.96-9.26-5.87-10.8-5.59-2.76-10.79-2.48-15.59.89-5.16,3.63-6.9,8.92-5.88,15.06-3.44,1.79-6.77,3.46-10.03,5.27-1.04.58-1.67.45-2.57-.24-4.36-3.31-9.77-3.35-14.45-.38-2.92,1.85-5.92,3.61-8.99,5.2-4.67,2.41-8.51,5.37-9.23,11.06-.06.44-.81,1.01-1.34,1.15-2.64.72-5.32,1.29-7.97,1.98-1.09.28-1.8-.03-2.5-.87-3.33-4.01-7.59-5.28-12.62-4.14-3.55.8-7.1,1.63-10.65,2.41-4.53.99-8.9,2.23-11.5,6.61-.14.23-.76.32-1.12.26-3.14-.54-6.26-1.14-9.44-1.73-.4-4.66-2.91-7.77-6.66-10.13-3.81-2.39-7.54-4.92-11.29-7.41-2.5-1.65-5.47-2.9-8.14-1.91-3.92,1.46-5.66-.68-7.62-3.11-.53-.65-1.1-1.28-1.71-1.87-.91-.89-1.15-1.7-.63-3.04,2.56-6.58-1.25-14.13-8-16.06-4.78-1.36-9.57-2.67-14.37-3.94-6.58-1.74-12.14.91-14.99,7.05-.24.51-.79,1.18-1.25,1.23-1.63.18-3.26.33-4.89.46.01.52.01,1.04.01,1.56,4.44-1,8.77-1.17,13.19-.6-1.82,1.27-8.29,2.27-13.22,2.36-.04,1.47-.13,2.95-.23,4.43,4.6-.4,9.19-.79,13.79-1.19.01.08.02.15.03.23-2.2.7-4.39,1.39-6.62,2.09,1.3,2.68,3.69,4.83,6.67,5.69,5.33,1.55,10.69,3.06,16.09,4.37,1.72.42,3.61.13,5.84.18-1.34-2.39-2.39-4.26-3.44-6.13l.3-.23c5.72,6.3,11.43,12.61,17.15,18.91-.06.07-.12.13-.18.2-2.04-1.41-4.09-2.82-6.2-4.27-1.71,5.48.04,10.66,4.66,13.84,4.3,2.96,8.67,5.81,13.05,8.64,5.02,3.25,12.27,1.96,15.19-2.14-2.16-.92-4.3-1.83-6.44-2.74.05-.15.11-.3.16-.45,6.02,1.12,12.04,2.21,18.04,3.4.43.09.91.85,1.05,1.39,1.65,6.24,7.78,10.23,14.06,8.93,4.97-1.03,9.89-2.3,14.84-3.41,4.98-1.12,8.06-4.16,9.57-9.25-2.61.09-5,.18-7.4.27l-.02-.24,27-6.51c.05.15.09.31.14.46l-6.85,3.18c3.69,3.77,9.13,4.98,13.57,2.64,5.32-2.8,10.5-5.87,15.62-9.01,2.83-1.74,5.21-6.46,4.49-8.99-2.38.52-4.76,1.04-7.15,1.57-.01-.08-.03-.16-.04-.24l24.55-13.02.16.19c-1.43,1.36-2.86,2.72-4.35,4.14,4.09,3.31,8.57,4.15,13.26,2.79,5.85-1.7,9.32-5.87,10.62-12.29.39.9.81,1.74,1.2,2.55ZM240.66,6.17c2.19-1.05,6.89,2.57,6.7,5.28-2.92-.11-5.18-1.48-7-3.61-.24-.3-.01-1.52.3-1.67ZM236.31,14.54c-1.54,1.54-1.21,3.32.9,6.16-5.49-1.54-10.72-3-15.95-4.46.03-.17.07-.35.1-.52,2.43-.24,5.06-.28,5.67-3.36.39-1.94-.51-3.39-2.17-4.55,2.51.68,5.01,1.35,7.52,2.03,2.26.62,4.57,1.13,6.77,1.94,1.26.46,2.34,1.39,3.48,1.83-1.1-.18-2.23-.61-3.28-.46-1.08.15-2.29.64-3.04,1.39ZM243.02,19.76c3.02.35,11.2,8.77,12.25,12.7-4.84-3.4-8.69-7.74-12.25-12.7ZM271.35,48.21c-.99,2.02-.01,3.61,1.22,5.22-5.37-3.34-10.84-6.47-15.54-10.72.94.54,1.85,1.43,2.84,1.53,1.04.11,2.39-.23,3.21-.87,1.98-1.55,1.71-3.13-.61-7.24,4.91,3.25,9.83,6.5,14.74,9.76-2.44-.05-4.65-.17-5.86,2.32ZM267.38,32.23c4.46,2.84,9.48,4.89,13.41,9.32-2.49.4-12.99-7.11-13.41-9.32ZM284.99,50.83c3.61-1.39,15.07.42,17.7,2.77-5.94.19-11.65-.91-17.7-2.77ZM322.43,48.01c-2.55,1.22-3.64,2.83-3.16,4.68.58,2.26,2.21,3.21,5.16,3.2-6.25,1.93-12.54,3.69-19.16,4.1,2.4-.49,4.56-1.22,4.65-4.09.1-2.89-1.86-4.04-4.44-4.56,5.59-1.28,11.18-2.56,16.76-3.83.06.16.13.33.19.5ZM315.23,43.15c2.4-2.34,6.44-2.95,8.44-1.33-1.16,2.42-6.21,3.29-8.44,1.33ZM333.09,48.29c5.19-3.09,10.81-4.61,16.85-4.57-5.26,2.89-10.96,4.09-16.85,4.57ZM371.58,39.47l-15.81,9.08c-.12-.12-.24-.24-.36-.36,2.07-1.36,3.17-3.17,2.04-5.48-1.15-2.36-3.34-2.39-5.68-1.99,5.35-3.33,10.55-6.82,16.39-9.16-1.98,1.91-2.68,3.81-1.86,5.56.82,1.73,2.46,2.39,5.28,2.35ZM370.85,27.31c-2,.5-4.03.9-6.07,1.18-.43.06-1.37-.52-1.35-.76.03-.55.45-1.12.83-1.59.23-.28.67-.38,1.02-.57v-.42c1.79,0,3.58-.04,5.36.07.42.02.8.55,1.2.84-.33.43-.58,1.15-.99,1.25ZM378.71,29.44c4.29-4.26,9.38-7.12,15.26-8.59-4.37,4.11-9.65,6.64-15.26,8.59ZM391.92,14.77c-.33.39-1.13.37-1.71.54-.13-.58-.44-1.19-.34-1.73.4-2.33,2.42-4.9,4.89-6.03.17,0,.77.02,1.38.03-.03.62
.17,1.4-.12,1.83-1.28,1.85-2.65,3.64-4.1,5.36ZM407.84,23.73c-1.86,1.82-5.89,3.26-8.87,1.19.94-1.27,2.06-2.44,2.73-3.83.31-.64-.06-1.82-.47-2.57-1.06-1.94-3.17-2.19-6.12-.83.01-3.35,2.27-5.98,5.73-6.88,3.25-.84,6.83.81,8.56,3.94,1.53,2.76.85,6.6-1.56,8.98Z"/> <path d="M418.39,22.56c-.9-2.12-3.08-3.99-2.86-6.3.6-6.24-1.96-9.26-5.87-10.8-5.59-2.76-10.79-2.48-15.59.89-5.16,3.63-6.9,8.92-5.88,15.06-3.44,1.79-6.77,3.46-10.03,5.27-1.04.58-1.67.45-2.57-.24-4.36-3.31-9.77-3.35-14.45-.38-2.92,1.85-5.92,3.61-8.99,5.2-4.67,2.41-8.51,5.37-9.23,11.06-.06.44-.81,1.01-1.34,1.15-2.64.72-5.32,1.29-7.97,1.98-1.09.28-1.8-.03-2.5-.87-3.33-4.01-7.59-5.28-12.62-4.14-3.55.8-7.1,1.63-10.65,2.41-4.53.99-8.9,2.23-11.5,6.61-.14.23-.76.32-1.12.26-3.14-.54-6.26-1.14-9.44-1.73-.4-4.66-2.91-7.77-6.66-10.13-3.81-2.39-7.54-4.92-11.29-7.41-2.5-1.65-5.47-2.9-8.14-1.91-3.92,1.46-5.66-.68-7.62-3.11-.53-.65-1.1-1.28-1.71-1.87-.91-.89-1.15-1.7-.63-3.04,2.56-6.58-1.25-14.13-8-16.06-4.78-1.36-9.57-2.67-14.37-3.94-6.58-1.74-12.14.91-14.99,7.05-.24.51-.79,1.18-1.25,1.23-1.63.18-3.26.33-4.89.46.01.52.01,1.04.01,1.56,4.44-1,8.77-1.17,13.19-.6-1.82,1.27-8.29,2.27-13.22,2.36-.04,1.47-.13,2.95-.23,4.43,4.6-.4,9.19-.79,13.79-1.19.01.08.02.15.03.23-2.2.7-4.39,1.39-6.62,2.09,1.3,2.68,3.69,4.83,6.67,5.69,5.33,1.55,10.69,3.06,16.09,4.37,1.72.42,3.61.13,5.84.18-1.34-2.39-2.39-4.26-3.44-6.13l.3-.23c5.72,6.3,11.43,12.61,17.15,18.91-.06.07-.12.13-.18.2-2.04-1.41-4.09-2.82-6.2-4.27-1.71,5.48.04,10.66,4.66,13.84,4.3,2.96,8.67,5.81,13.05,8.64,5.02,3.25,12.27,1.96,15.19-2.14-2.16-.92-4.3-1.83-6.44-2.74.05-.15.11-.3.16-.45,6.02,1.12,12.04,2.21,18.04,3.4.43.09.91.85,1.05,1.39,1.65,6.24,7.78,10.23,14.06,8.93,4.97-1.03,9.89-2.3,14.84-3.41,4.98-1.12,8.06-4.16,9.57-9.25-2.61.09-5,.18-7.4.27l-.02-.24,27-6.51c.05.15.09.31.14.46l-6.85,3.18c3.69,3.77,9.13,4.98,13.57,2.64,5.32-2.8,10.5-5.87,15.62-9.01,2.83-1.74,5.21-6.46,4.49-8.99-2.38.52-4.76,1.04-7.15,1.57-.01-.08-.03-.16-.04-.24l24.55-13.02.16.19c-1.43,1.36-2.86,2.72-4.35,4.14,4.09,3.31,8.57,4.15,13.26,2.79,5.85-1.7,9.32-5.87,10.62-12.29.39.9.81,1.74,1.2,2.55ZM240.66,6.17c2.19-1.05,6.89,2.57,6.7,5.28-2.92-.11-5.18-1.48-7-3.61-.24-.3-.01-1.52.3-1.67ZM236.31,14.54c-1.54,1.54-1.21,3.32.9,6.16-5.49-1.54-10.72-3-15.95-4.46.03-.17.07-.35.1-.52,2.43-.24,5.06-.28,5.67-3.36.39-1.94-.51-3.39-2.17-4.55,2.51.68,5.01,1.35,7.52,2.03,2.26.62,4.57,1.13,6.77,1.94,1.26.46,2.34,1.39,3.48,1.83-1.1-.18-2.23-.61-3.28-.46-1.08.15-2.29.64-3.04,1.39ZM243.02,19.76c3.02.35,11.2,8.77,12.25,12.7-4.84-3.4-8.69-7.74-12.25-12.7ZM271.35,48.21c-.99,2.02-.01,3.61,1.22,5.22-5.37-3.34-10.84-6.47-15.54-10.72.94.54,1.85,1.43,2.84,1.53,1.04.11,2.39-.23,3.21-.87,1.98-1.55,1.71-3.13-.61-7.24,4.91,3.25,9.83,6.5,14.74,9.76-2.44-.05-4.65-.17-5.86,2.32ZM267.38,32.23c4.46,2.84,9.48,4.89,13.41,9.32-2.49.4-12.99-7.11-13.41-9.32ZM284.99,50.83c3.61-1.39,15.07.42,17.7,2.77-5.94.19-11.65-.91-17.7-2.77ZM322.43,48.01c-2.55,1.22-3.64,2.83-3.16,4.68.58,2.26,2.21,3.21,5.16,3.2-6.25,1.93-12.54,3.69-19.16,4.1,2.4-.49,4.56-1.22,4.65-4.09.1-2.89-1.86-4.04-4.44-4.56,5.59-1.28,11.18-2.56,16.76-3.83.06.16.13.33.19.5ZM315.23,43.15c2.4-2.34,6.44-2.95,8.44-1.33-1.16,2.42-6.21,3.29-8.44,1.33ZM333.09,48.29c5.19-3.09,10.81-4.61,16.85-4.57-5.26,2.89-10.96,4.09-16.85,4.57ZM371.58,39.47l-15.81,9.08c-.12-.12-.24-.24-.36-.36,2.07-1.36,3.17-3.17,2.04-5.48-1.15-2.36-3.34-2.39-5.68-1.99,5.35-3.33,10.55-6.82,16.39-9.16-1.98,1.91-2.68,3.81-1.86,5.56.82,1.73,2.46,2.39,5.28,2.35ZM370.85,27.31c-2,.5-4.03.9-6.07,1.18-.43.06-1.37-.52-1.35-.76.03-.55.45-1.12.83-1.59.23-.28.67-.38,1.02-.57v-.42c
1.79,0,3.58-.04,5.36.07.42.02.8.55,1.2.84-.33.43-.58,1.15-.99,1.25ZM378.71,29.44c4.29-4.26,9.38-7.12,15.26-8.59-4.37,4.11-9.65,6.64-15.26,8.59ZM391.92,14.77c-.33.39-1.13.37-1.71.54-.13-.58-.44-1.19-.34-1.73.4-2.33,2.42-4.9,4.89-6.03.17,0,.77.02,1.38.03-.03.62.17,1.4-.12,1.83-1.28,1.85-2.65,3.64-4.1,5.36ZM407.84,23.73c-1.86,1.82-5.89,3.26-8.87,1.19.94-1.27,2.06-2.44,2.73-3.83.31-.64-.06-1.82-.47-2.57-1.06-1.94-3.17-2.19-6.12-.83.01-3.35,2.27-5.98,5.73-6.88,3.25-.84,6.83.81,8.56,3.94,1.53,2.76.85,6.6-1.56,8.98Z"/>
  <circle class="cls-1" cx="206.14" cy="15.03" r="8.22"/>
</svg>
\ No newline at end of file
@@ -23,4 +23,4 @@
  <polygon class="cls-1" points="538.46 21.84 538.46 39.47 503.18 0 484.49 0 484.49 63.62 502.88 63.62 502.88 24.15 538.26 63.62 556.85 63.62 556.85 21.84 538.46 21.84"/>
  <rect class="cls-2" x="538.46" width="18.39" height="14.25"/>
  <path class="cls-1" d="M565.55,14.12V0h67.25v14.12h-23.48v49.5h-18.39V14.12h-25.38Z"/>
</svg>
\ No newline at end of file
@@ -40,4 +40,4 @@ For detailed guidance on testing, refer to the [`tests/README.md`](../tests/READ
## Acknowledgments

This contribution guide is adapted from [SGLang](https://docs.sglang.ai/references/contribution_guide.html). We thank them for the inspiration.
\ No newline at end of file
@@ -62,17 +62,17 @@ Then verify the Python version and installed PyTorch version:

Install PyTorch appropriate for your setup

- **For most users**:

  ```bash
  "G:\ComfyuI\python\python.exe" -m pip install torch==2.6 torchvision==0.21 torchaudio==2.6
  ```

- **For RTX 50-series GPUs** (requires PyTorch ≥2.7 with CUDA 12.8):

  ```bash
  "G:\ComfyuI\python\python.exe" -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
  ```
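After installing, you can confirm that the interpreter you will use for ComfyUI sees the expected PyTorch build (run this with the same `python.exe` as in the commands above):

```python
import sys
import torch

print(sys.version)                # Python interpreter version
print(torch.__version__)          # e.g. 2.6.0, or a 2.7 nightly for RTX 50-series GPUs
print(torch.version.cuda)         # CUDA version the wheel was built against
print(torch.cuda.is_available())  # should be True on a working GPU setup
```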
## Step 3: Install Nunchaku

@@ -108,55 +108,55 @@ You can also run a test (requires a Hugging Face token for downloading the model

Please use CMD instead of PowerShell for building.

- Step 1: Install Build Tools

  ```bash
  C:\Users\muyang\miniconda3\envs\comfyui\python.exe
  "G:\ComfyuI\python\python.exe" -m pip install ninja setuptools wheel build
  ```

- Step 2: Clone the Repository

  ```bash
  git clone https://github.com/mit-han-lab/nunchaku.git
  cd nunchaku
  git submodule init
  git submodule update
  ```

- Step 3: Set Up Visual Studio Environment

  Locate the `VsDevCmd.bat` script on your system. Example path:

  ```
  C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat
  ```

  Then run:

  ```bash
  "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
  set DISTUTILS_USE_SDK=1
  ```

- Step 4: Build Nunchaku

  ```bash
  "G:\ComfyuI\python\python.exe" setup.py develop
  ```

  Verify with:

  ```bash
  "G:\ComfyuI\python\python.exe" -c "import nunchaku"
  ```

  You can also run a test (requires a Hugging Face token for downloading the models):

  ```bash
  "G:\ComfyuI\python\python.exe" -m huggingface-cli login
  "G:\ComfyuI\python\python.exe" -m nunchaku.test
  ```

- (Optional) Step 5: Building wheel for Portable Python

  If building directly with portable Python fails, you can first build the wheel in a working Conda environment, then install the `.whl` file using your portable Python:
@@ -182,42 +182,42 @@ Alternatively, install using [ComfyUI-Manager](https://github.com/Comfy-Org/Comf

## 2. Download Models

- **Standard FLUX.1-dev Models**

  Start by downloading the standard [FLUX.1-dev text encoders](https://huggingface.co/comfyanonymous/flux_text_encoders/tree/main) and [VAE](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors). You can also optionally download the original [BF16 FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors) model. An example command:

  ```bash
  huggingface-cli download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models/text_encoders
  huggingface-cli download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models/text_encoders
  huggingface-cli download black-forest-labs/FLUX.1-schnell ae.safetensors --local-dir models/vae
  huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/diffusion_models
  ```

- **SVDQuant 4-bit FLUX.1-dev Models**

  Next, download the SVDQuant 4-bit models:

  - For **50-series GPUs**, use the [FP4 model](https://huggingface.co/mit-han-lab/svdq-fp4-flux.1-dev).
  - For **other GPUs**, use the [INT4 model](https://huggingface.co/mit-han-lab/svdq-int4-flux.1-dev).

  Make sure to place the **entire downloaded folder** into `models/diffusion_models`. For example:

  ```bash
  huggingface-cli download mit-han-lab/svdq-int4-flux.1-dev --local-dir models/diffusion_models/svdq-int4-flux.1-dev
  ```

- **(Optional): Download Sample LoRAs**

  You can test with some sample LoRAs like [FLUX.1-Turbo](https://huggingface.co/alimama-creative/FLUX.1-Turbo-Alpha/blob/main/diffusion_pytorch_model.safetensors) and [Ghibsky](https://huggingface.co/aleksa-codes/flux-ghibsky-illustration/blob/main/lora.safetensors). Place these files in the `models/loras` directory:

  ```bash
  huggingface-cli download alimama-creative/FLUX.1-Turbo-Alpha diffusion_pytorch_model.safetensors --local-dir models/loras
  huggingface-cli download aleksa-codes/flux-ghibsky-illustration lora.safetensors --local-dir models/loras
  ```
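If you prefer scripting these downloads, the same files can be fetched with the `huggingface_hub` Python library instead of `huggingface-cli` (a sketch mirroring the commands above; adjust repo IDs and target folders to your layout):

```python
from huggingface_hub import hf_hub_download, snapshot_download

# Single files (text encoder, VAE, LoRA) mirror the huggingface-cli commands above.
hf_hub_download("comfyanonymous/flux_text_encoders", "clip_l.safetensors", local_dir="models/text_encoders")
hf_hub_download("black-forest-labs/FLUX.1-schnell", "ae.safetensors", local_dir="models/vae")
hf_hub_download("aleksa-codes/flux-ghibsky-illustration", "lora.safetensors", local_dir="models/loras")

# The SVDQuant model is a whole folder, so download the entire repository.
snapshot_download("mit-han-lab/svdq-int4-flux.1-dev", local_dir="models/diffusion_models/svdq-int4-flux.1-dev")
```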
## 3. Set Up Workflows

To use the official workflows, download them from the [ComfyUI-nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku/tree/main/workflows) repository and place them in your `ComfyUI/user/default/workflows` directory. The command can be:

```bash
# From the root of your ComfyUI folder
```
@@ -231,4 +231,4 @@ You can now launch ComfyUI and try running the example workflows.

If you encounter issues, refer to our:

- [FAQs](https://github.com/mit-han-lab/nunchaku/discussions/262)
- [GitHub Issues](https://github.com/mit-han-lab/nunchaku/issues)
\ No newline at end of file
@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_gpu_memory, get_precision

base_model = "black-forest-labs/FLUX.1-dev"
@@ -29,11 +28,6 @@ if need_offload:
else:
    pipeline = pipeline.to("cuda")

prompt = "A anime style girl with messy beach waves."
control_image_depth = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
...
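Besides the reflow, this file drops the unused `apply_cache_on_pipe` import and the commented-out call that enabled first-block caching. For reference, the removed snippet amounts to the following, which can still be applied to the assembled pipeline to trade a little fidelity for speed (the 0.1 threshold is the value from the removed comment):

```python
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe

# First-block cache: when the first transformer block's residual changes little
# between denoising steps, cached results are reused to skip work.
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.1)
```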
@@ -7,14 +7,10 @@ from nunchaku.utils import get_precision
precision = get_precision()

transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

apply_cache_on_pipe(
@@ -24,9 +20,6 @@ apply_cache_on_pipe(
    residual_diff_threshold_single=0.12,
)

image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
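These names are re-exported at the package root, so downstream code imports them directly, as the examples earlier in this diff do:

```python
from nunchaku import NunchakuFluxTransformer2dModel

# e.g. load the quantized FLUX transformer exactly as in the caching example above
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
```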
@@ -20,7 +20,8 @@ public:
        ModuleWrapper::init(deviceId);
        CUDADeviceContext ctx(this->deviceId);

        net = std::make_unique<FluxModel>(
            use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

    bool isBF16() {
@@ -28,52 +29,50 @@ public:
        return net->dtype == Tensor::BF16;
    }

    pybind11::function residual_callback;

    void set_residual_callback(pybind11::function callback) {
        pybind11::gil_scoped_acquire gil;
        if (!callback || callback.is_none()) {
            residual_callback = pybind11::function();
            if (net) {
                net->set_residual_callback(nullptr);
            }
            return;
        }
        residual_callback = std::move(callback);
        if (net) {
            pybind11::object cb = residual_callback;
            net->set_residual_callback([cb](const Tensor &x) -> Tensor {
                pybind11::gil_scoped_acquire gil;
                torch::Tensor torch_x = to_torch(x);
                pybind11::object result = cb(torch_x);
                torch::Tensor torch_y = result.cast<torch::Tensor>();
                Tensor y = from_torch(torch_y);
                return y;
            });
        } else {
        }
    }

    torch::Tensor forward(torch::Tensor hidden_states,
                          torch::Tensor encoder_hidden_states,
                          torch::Tensor temb,
                          torch::Tensor rotary_emb_img,
                          torch::Tensor rotary_emb_context,
                          torch::Tensor rotary_emb_single,
                          std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
                          std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
                          bool skip_first_layer = false) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward");

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();
        rotary_emb_single = rotary_emb_single.contiguous();

        Tensor result = net->forward(
            from_torch(hidden_states),
@@ -83,9 +82,10 @@ public:
            from_torch(rotary_emb_context),
            from_torch(rotary_emb_single),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value()
                ? from_torch(controlnet_single_block_samples.value().contiguous())
                : Tensor{},
            skip_first_layer);

        torch::Tensor output = to_torch(result);
        Tensor::synchronizeDevice();
@@ -93,25 +93,24 @@ public:
        return output;
    }

    std::tuple<torch::Tensor, torch::Tensor>
    forward_layer(int64_t idx,
                  torch::Tensor hidden_states,
                  torch::Tensor encoder_hidden_states,
                  torch::Tensor temb,
                  torch::Tensor rotary_emb_img,
                  torch::Tensor rotary_emb_context,
                  std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
                  std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();

        auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
            idx,
@@ -121,35 +120,31 @@ public:
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value()
                ? from_torch(controlnet_single_block_samples.value().contiguous())
                : Tensor{});

        hidden_states = to_torch(hidden_states_);
        encoder_hidden_states = to_torch(encoder_hidden_states_);
        Tensor::synchronizeDevice();

        return {hidden_states, encoder_hidden_states};
    }

    torch::Tensor forward_single_layer(int64_t idx,
                                       torch::Tensor hidden_states,
                                       torch::Tensor temb,
                                       torch::Tensor rotary_emb_single) {
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_single = rotary_emb_single.contiguous();

        Tensor result = net->single_transformer_blocks.at(idx)->forward(
            from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));

        hidden_states = to_torch(result);
        Tensor::synchronizeDevice();
@@ -159,19 +154,15 @@ public:
    // expose the norm1 forward method of the transformer blocks
    // this is used by TeaCache to get the norm1 output
    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
    norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) {
        AdaLayerNormZero::Output result =
            net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
        return {to_torch(result.x),
                to_torch(result.gate_msa),
                to_torch(result.shift_mlp),
                to_torch(result.scale_mlp),
                to_torch(result.gate_mlp)};
    }

    // must be called after loading lora
@@ -214,5 +205,4 @@ public:
            throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
        }
    }
};
\ No newline at end of file
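The callback plumbing above converts the C++ residual tensor to a `torch.Tensor`, calls back into Python with the GIL held, and converts the returned tensor back. A hedged sketch of what the Python side of that contract can look like (the wrapper object is the `QuantizedFluxModel` binding registered later in this diff; how you obtain an instance depends on the Python-side model wiring):

```python
import torch

def residual_hook(residual: torch.Tensor) -> torch.Tensor:
    # Called from inside the C++ forward pass with the GIL held, so keep it cheap.
    # Returning the tensor unchanged makes this a pure observer.
    print("residual norm:", residual.float().norm().item())
    return residual

# model.set_residual_callback(residual_hook)  # register on a QuantizedFluxModel binding
# model.set_residual_callback(None)           # passing None clears the callback, as handled above
```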
@@ -16,7 +16,12 @@ public:
        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
        spdlog::debug("Stack={}", val);

        net = std::make_unique<GEMM_W4A4>((int)in_features,
                                          (int)out_features,
                                          bias,
                                          use_fp4,
                                          bf16 ? Tensor::BF16 : Tensor::FP16,
                                          Device::cuda((int)deviceId));
    }

    torch::Tensor forward(torch::Tensor x) {
@@ -53,11 +58,11 @@ public:
        // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
        constexpr int BLOCK_M = 256;
        constexpr int WARP_K = 64;
        constexpr int NUM_WARPS = 8;
        constexpr int WARP_M_TILES = 2;
        constexpr int WARP_SIZE = 32;
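        // Worked example of the layout above (added for clarity): for M = 256, K = 64
        // the packed activation holds (M/BLOCK_M) * (K/WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE
        //   = 1 * 1 * 8 * 2 * 32 = 512 packed_act_t (uint4) entries,
        // i.e. 512 * 16 B = 8 KiB, which matches 256 * 64 int4 values at half a byte each.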
        std::stringstream ss;

        for (int bm = 0; bm < M / BLOCK_M; bm++) {
@@ -95,13 +100,10 @@ public:
        x = x.contiguous();

        auto qout = net->quantize(from_torch(x), fuse_glu);

        Tensor act = qout.act.copy(Device::cpu());
        Tensor ascales = qout.ascales.copy(Device::cpu());
        Tensor lora_act = qout.lora_act.copy(Device::cpu());
        Tensor::synchronizeDevice();
@@ -109,5 +111,4 @@ public:
        spdlog::debug("act = {}", dumpTensorINT4(act));
        spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
    }
};
@@ -10,13 +10,14 @@ class QuantizedGEMM88 : public ModuleWrapper<GEMM_W8A8> {
public:
    void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedGEMM88");

        size_t val = 0;
        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
        spdlog::debug("Stack={}", val);

        net = std::make_unique<GEMM_W8A8>(
            (int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

    torch::Tensor forward(torch::Tensor x) {
@@ -27,10 +28,10 @@ public:
        x = x.contiguous();

        Tensor result = net->forward(from_torch(x));
        torch::Tensor output = to_torch(result);

        Tensor::synchronizeDevice();

        return output;
    }
};
\ No newline at end of file
@@ -18,7 +18,7 @@ public:
        debugContext.reset();
        net.reset();
        Tensor::synchronizeDevice();

        nunchaku::utils::trim_memory();
        Tensor::synchronizeDevice();
    }
@@ -28,7 +28,7 @@ public:
        CUDADeviceContext ctx(this->deviceId);

        spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
        std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);

        net->loadParams(*provider, partial);
        Tensor::synchronizeDevice();
@@ -41,7 +41,7 @@ public:
        CUDADeviceContext ctx(this->deviceId);

        spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
        std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));

        net->loadParams(*provider, partial);
        Tensor::synchronizeDevice();
@@ -66,7 +66,7 @@ public:
                result[key] = to_torch(value);
            }
        }

        return result;
    }
@@ -82,4 +82,4 @@ protected:
    std::unique_ptr<DebugContext> debugContext;

    int deviceId = -1;
};
\ No newline at end of file
@@ -7,175 +7,132 @@

namespace nunchaku::ops {

void gemm_w4a4(std::optional<torch::Tensor> act,            // packed act [M, K / 2]
               std::optional<torch::Tensor> wgt,            // packed act [N, K / 2]
               std::optional<torch::Tensor> out,            // linear [M, N]
               std::optional<torch::Tensor> qout,           // packed act [M, N / 2]
               std::optional<torch::Tensor> ascales,        // packed as [K / 64, M]
               std::optional<torch::Tensor> wscales,        // packed ws [K / 64, N]
               std::optional<torch::Tensor> oscales,        // packed as [N / 64, M]
               std::optional<torch::Tensor> poolout,        // linear [M / PoolSize, N]
               std::optional<torch::Tensor> lora_act_in,    // packed lora_act [M, R]
               std::optional<torch::Tensor> lora_up,        // packed lora_wgt [N, R]
               std::optional<torch::Tensor> lora_down,      // packed lora_wgt [N, R]
               std::optional<torch::Tensor> lora_act_out,   // packed lora_act [M, R]
               std::optional<torch::Tensor> norm_q,         // linear [HEAD_DIM]
               std::optional<torch::Tensor> norm_k,         // linear [HEAD_DIM]
               std::optional<torch::Tensor> rotary_emb,     // linear [M, HEAD_DIM / 2, 2, 2]
               std::optional<torch::Tensor> bias,           // packed ws [N]
               std::optional<torch::Tensor> smooth_factor,  // packed ws [N], for quantization of the next layer
               std::optional<torch::Tensor> out_vk,         // linear [B, num_heads, head_dim + 1, head_dim]
               std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
               bool act_unsigned,
               std::vector<float> lora_scales,
               bool fuse_silu,
               bool fp4,
               float alpha,
               std::optional<torch::Tensor> wcscales,
               std::optional<torch::Tensor> out_q, // packed attention [B, H, M, D]
               std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
               std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
               int attn_tokens) {
    spdlog::trace("running gemm_w4a4: ");

    auto getTensor = [](std::optional<torch::Tensor> &t) {
        Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
        if (ret.valid()) {
            spdlog::trace(" {}", ret.shape.str());
        } else {
            spdlog::trace(" <invalid>");
        }
        return ret;
    };

    nunchaku::kernels::gemm_w4a4(getTensor(act),
                                 getTensor(wgt),
                                 getTensor(out),
                                 getTensor(qout),
                                 getTensor(ascales),
                                 getTensor(wscales),
                                 getTensor(oscales),
                                 getTensor(poolout),
                                 getTensor(lora_act_in),
                                 getTensor(lora_up),
                                 getTensor(lora_down),
                                 getTensor(lora_act_out),
                                 getTensor(norm_q),
                                 getTensor(norm_k),
                                 getTensor(rotary_emb),
                                 getTensor(bias),
                                 getTensor(smooth_factor),
                                 getTensor(out_vk),
                                 getTensor(out_linearattn),
                                 act_unsigned,
                                 lora_scales,
                                 fuse_silu,
                                 fp4,
                                 alpha,
                                 getTensor(wcscales),
                                 getTensor(out_q),
                                 getTensor(out_k),
                                 getTensor(out_v),
                                 attn_tokens);
    // Tensor::synchronizeDevice();
}

void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
                    torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
                    float scale) {
    nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
}

torch::Tensor gemv_awq(torch::Tensor _in_feats,
                       torch::Tensor _kernel,
                       torch::Tensor _scaling_factors,
                       torch::Tensor _zeros,
                       int64_t m,
                       int64_t n,
                       int64_t k,
                       int64_t group_size) {
    Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
                               from_torch(_kernel.contiguous()),
                               from_torch(_scaling_factors.contiguous()),
                               from_torch(_zeros.contiguous()),
                               (int)m,
                               (int)n,
                               (int)k,
                               (int)group_size);

    torch::Tensor output = to_torch(result);
    // Tensor::synchronizeDevice();
    return output;
}

torch::Tensor
gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) {
    Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()),
                                            from_torch(_kernel.contiguous()),
                                            from_torch(_scaling_factors.contiguous()),
                                            from_torch(_zeros.contiguous()));

    // TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
    torch::Tensor output = to_torch(result);
    // Tensor::synchronizeDevice();
    return output;
}

void test_rmsnorm_rope(
    torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
    nunchaku::kernels::test_rmsnorm_rope(
        from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
}

void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
    nunchaku::kernels::test_pack_qkv(
        from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
}

}; // namespace nunchaku::ops
\ No newline at end of file
@@ -11,49 +11,44 @@

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
        .def(py::init<>())
        .def("init",
             &QuantizedFluxModel::init,
             py::arg("use_fp4"),
             py::arg("offload"),
             py::arg("bf16"),
             py::arg("deviceId"))
        .def("set_residual_callback",
             [](QuantizedFluxModel &self, pybind11::object call_back) {
                 if (call_back.is_none()) {
                     self.set_residual_callback(pybind11::function());
                 } else {
                     self.set_residual_callback(call_back);
                 }
             })
        .def("reset", &QuantizedFluxModel::reset)
        .def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false)
        .def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false)
        .def("forward",
             &QuantizedFluxModel::forward,
             py::arg("hidden_states"),
             py::arg("encoder_hidden_states"),
             py::arg("temb"),
             py::arg("rotary_emb_img"),
             py::arg("rotary_emb_context"),
             py::arg("rotary_emb_single"),
             py::arg("controlnet_block_samples") = py::none(),
             py::arg("controlnet_single_block_samples") = py::none(),
             py::arg("skip_first_layer") = false)
        .def("forward_layer",
             &QuantizedFluxModel::forward_layer,
             py::arg("idx"),
             py::arg("hidden_states"),
             py::arg("encoder_hidden_states"),
             py::arg("temb"),
             py::arg("rotary_emb_img"),
             py::arg("rotary_emb_context"),
             py::arg("controlnet_block_samples") = py::none(),
             py::arg("controlnet_single_block_samples") = py::none())
        .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
        .def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
        .def("startDebug", &QuantizedFluxModel::startDebug)
@@ -61,32 +56,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
        .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
        .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
        .def("isBF16", &QuantizedFluxModel::isBF16);

    py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
        .def(py::init<>())
        .def("init",
             &QuantizedSanaModel::init,
             py::arg("config"),
             py::arg("pag_layers"),
             py::arg("use_fp4"),
             py::arg("bf16"),
             py::arg("deviceId"))
        .def("reset", &QuantizedSanaModel::reset)
        .def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false)
        .def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false)
        .def("forward", &QuantizedSanaModel::forward)
        .def("forward_layer", &QuantizedSanaModel::forward_layer)
        .def("startDebug", &QuantizedSanaModel::startDebug)
        .def("stopDebug", &QuantizedSanaModel::stopDebug)
        .def("getDebugResults", &QuantizedSanaModel::getDebugResults);

    py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
        .def(py::init<>())
        .def("init", &QuantizedGEMM::init)
@@ -96,8 +83,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        .def("quantize", &QuantizedGEMM::quantize)
        .def("startDebug", &QuantizedGEMM::startDebug)
        .def("stopDebug", &QuantizedGEMM::stopDebug)
        .def("getDebugResults", &QuantizedGEMM::getDebugResults);

    py::class_<Tensor>(m, "Tensor");

    py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
        .def(py::init<>())
@@ -107,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        .def("forward", &QuantizedGEMM88::forward)
        .def("startDebug", &QuantizedGEMM88::startDebug)
        .def("stopDebug", &QuantizedGEMM88::stopDebug)
        .def("getDebugResults", &QuantizedGEMM88::getDebugResults);

    m.def_submodule("ops")
        .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
@@ -117,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        .def("gemv_awq", nunchaku::ops::gemv_awq)
        .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
        .def("test_pack_qkv", nunchaku::ops::test_pack_qkv);

    m.def_submodule("utils")
        .def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
        .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
        .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
        .def("trim_memory", nunchaku::utils::trim_memory)
        .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode);
}
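For orientation, these are the raw bindings that the Python package wraps. A sketch of poking at them directly, assuming the compiled extension is importable as `nunchaku._C` (the actual module name is set by `TORCH_EXTENSION_NAME` at build time and may differ in your install):

```python
# Assumption: the extension module defined above is exposed as nunchaku._C.
from nunchaku import _C

_C.utils.set_log_level("debug")  # forwarded to spdlog::set_level, per the binding above

model = _C.QuantizedFluxModel()
model.init(use_fp4=False, offload=False, bf16=True, deviceId=0)
# model.load("path/to/quantized/weights.safetensors")  # hypothetical path; see load()/loadDict() above
print(model.isBF16())
```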