Unverified commit 57e50f8d authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
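For orientation, the convention the upgraded linter converges on is visible in the before/after pairs throughout the diff below. A representative pair, excerpted verbatim from one of the updated example scripts further down (no new code, just the same change repeated here), collapses a call onto a single line once it fits within the line-length limit:

# Before the reformat
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/svdq-{precision}-flux.1-dev"
)

# After the reformat
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")

The exact line-length setting is not stated in this commit; judging from the reformatted lines it appears to be on the order of 110-120 characters.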
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from vars import LORA_PATHS, SVDQ_LORA_PATHS
from nunchaku import NunchakuFluxTransformer2dModel
from vars import LORA_PATHS, SVDQ_LORA_PATHS
def hash_str_to_int(s: str) -> int:
......
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
<a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo
......
......@@ -2,7 +2,6 @@ import argparse
import os
import torch
from utils import get_pipeline
......
......@@ -4,7 +4,6 @@ import time
import torch
from torch import nn
from tqdm import trange
from utils import get_pipeline
......
......@@ -8,13 +8,13 @@ from datetime import datetime
import GPUtil
import spaces
import torch
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import EXAMPLES, MAX_SEED
from nunchaku.models.safety_checker import SafetyChecker
# import gradio last to avoid conflicts with other imports
import gradio as gr
import gradio as gr # noqa: isort: skip
def get_args() -> argparse.Namespace:
......@@ -73,7 +73,7 @@ def generate(
prompt = "A peaceful world."
images, latency_strs = [], []
for i, pipeline in enumerate(pipelines):
progress = gr.Progress(track_tqdm=True)
gr.Progress(track_tqdm=True)
start_time = time.time()
image = pipeline(
prompt=prompt,
......@@ -124,11 +124,11 @@ if len(gpus) > 0:
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
with gr.Blocks(
css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
title=f"SVDQuant SANA-1600M Demo",
title="SVDQuant SANA-1600M Demo",
) as demo:
def get_header_str():
......
......@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe
from nunchaku.utils import get_gpu_memory, get_precision
base_model = "black-forest-labs/FLUX.1-dev"
......@@ -29,11 +28,6 @@ if need_offload:
else:
pipeline = pipeline.to("cuda")
# apply_cache_on_pipe(
# pipeline, residual_diff_threshold=0.1
# ) # Uncomment this line to enable first-block cache to speedup generation
prompt = "A anime style girl with messy beach waves."
control_image_depth = load_image(
"https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
......
......@@ -7,14 +7,10 @@ from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-dev"
)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
......@@ -24,9 +20,6 @@ apply_cache_on_pipe(
residual_diff_threshold_single=0.12,
)
image = pipeline(
["A cat holding a sign that says hello world"],
num_inference_steps=50
).images[0]
image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
......@@ -20,7 +20,8 @@ public:
ModuleWrapper::init(deviceId);
CUDADeviceContext ctx(this->deviceId);
net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<FluxModel>(
use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
bool isBF16() {
......@@ -32,7 +33,7 @@ public:
pybind11::gil_scoped_acquire gil;
if (!callback || callback.is_none()) {
residual_callback = pybind11::function();
if (net){
if (net) {
net->set_residual_callback(nullptr);
}
return;
......@@ -52,8 +53,7 @@ public:
}
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor forward(torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
......@@ -61,8 +61,7 @@ public:
torch::Tensor rotary_emb_single,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
bool skip_first_layer = false)
{
bool skip_first_layer = false) {
checkModel();
CUDADeviceContext ctx(deviceId);
......@@ -83,9 +82,10 @@ public:
from_torch(rotary_emb_context),
from_torch(rotary_emb_single),
controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
skip_first_layer
);
controlnet_single_block_samples.has_value()
? from_torch(controlnet_single_block_samples.value().contiguous())
: Tensor{},
skip_first_layer);
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
......@@ -93,16 +93,15 @@ public:
return output;
}
std::tuple<torch::Tensor, torch::Tensor> forward_layer(
int64_t idx,
std::tuple<torch::Tensor, torch::Tensor>
forward_layer(int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
{
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
......@@ -121,22 +120,21 @@ public:
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
);
controlnet_single_block_samples.has_value()
? from_torch(controlnet_single_block_samples.value().contiguous())
: Tensor{});
hidden_states = to_torch(hidden_states_);
encoder_hidden_states = to_torch(encoder_hidden_states_);
Tensor::synchronizeDevice();
return { hidden_states, encoder_hidden_states };
return {hidden_states, encoder_hidden_states};
}
torch::Tensor forward_single_layer(
int64_t idx,
torch::Tensor forward_single_layer(int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_single)
{
torch::Tensor rotary_emb_single) {
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
......@@ -146,10 +144,7 @@ public:
rotary_emb_single = rotary_emb_single.contiguous();
Tensor result = net->single_transformer_blocks.at(idx)->forward(
from_torch(hidden_states),
from_torch(temb),
from_torch(rotary_emb_single)
);
from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));
hidden_states = to_torch(result);
Tensor::synchronizeDevice();
......@@ -159,19 +154,15 @@ public:
// expose the norm1 forward method of the transformer blocks
// this is used by TeaCache to get the norm1 output
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> norm_one_forward(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb
) {
AdaLayerNormZero::Output result = net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
return {
to_torch(result.x),
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) {
AdaLayerNormZero::Output result =
net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
return {to_torch(result.x),
to_torch(result.gate_msa),
to_torch(result.shift_mlp),
to_torch(result.scale_mlp),
to_torch(result.gate_mlp)
};
to_torch(result.gate_mlp)};
}
// must be called after loading lora
......@@ -214,5 +205,4 @@ public:
throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
}
}
};
......@@ -16,7 +16,12 @@ public:
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<GEMM_W4A4>((int)in_features,
(int)out_features,
bias,
use_fp4,
bf16 ? Tensor::BF16 : Tensor::FP16,
Device::cuda((int)deviceId));
}
torch::Tensor forward(torch::Tensor x) {
......@@ -95,10 +100,7 @@ public:
x = x.contiguous();
auto qout = net->quantize(
from_torch(x),
fuse_glu
);
auto qout = net->quantize(from_torch(x), fuse_glu);
Tensor act = qout.act.copy(Device::cpu());
Tensor ascales = qout.ascales.copy(Device::cpu());
......@@ -109,5 +111,4 @@ public:
spdlog::debug("act = {}", dumpTensorINT4(act));
spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
}
};
......@@ -16,7 +16,8 @@ public:
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W8A8>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<GEMM_W8A8>(
(int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(torch::Tensor x) {
......
......@@ -7,8 +7,7 @@
namespace nunchaku::ops {
void gemm_w4a4(
std::optional<torch::Tensor> act, // packed act [M, K / 2]
void gemm_w4a4(std::optional<torch::Tensor> act, // packed act [M, K / 2]
std::optional<torch::Tensor> wgt, // packed act [N, K / 2]
std::optional<torch::Tensor> out, // linear [M, N]
std::optional<torch::Tensor> qout, // packed act [M, N / 2]
......@@ -26,7 +25,7 @@ namespace nunchaku::ops {
std::optional<torch::Tensor> bias, // packed ws [N]
std::optional<torch::Tensor> smooth_factor, // packed ws [N], for quantization of the next layer
std::optional<torch::Tensor> out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
std::optional<torch::Tensor> out_linearattn,// linear [B, (M), N / 3]
std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales,
bool fuse_silu,
......@@ -36,8 +35,7 @@ namespace nunchaku::ops {
std::optional<torch::Tensor> out_q, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
int attn_tokens
) {
int attn_tokens) {
spdlog::trace("running gemm_w4a4: ");
auto getTensor = [](std::optional<torch::Tensor> &t) {
......@@ -49,25 +47,24 @@ namespace nunchaku::ops {
}
return ret;
};
nunchaku::kernels::gemm_w4a4(
getTensor(act ),
getTensor(wgt ),
getTensor(out ),
getTensor(qout ),
getTensor(ascales ),
getTensor(wscales ),
getTensor(oscales ),
getTensor(poolout ),
getTensor(lora_act_in ),
getTensor(lora_up ),
getTensor(lora_down ),
getTensor(lora_act_out ),
getTensor(norm_q ),
getTensor(norm_k ),
getTensor(rotary_emb ),
getTensor(bias ),
nunchaku::kernels::gemm_w4a4(getTensor(act),
getTensor(wgt),
getTensor(out),
getTensor(qout),
getTensor(ascales),
getTensor(wscales),
getTensor(oscales),
getTensor(poolout),
getTensor(lora_act_in),
getTensor(lora_up),
getTensor(lora_down),
getTensor(lora_act_out),
getTensor(norm_q),
getTensor(norm_k),
getTensor(rotary_emb),
getTensor(bias),
getTensor(smooth_factor),
getTensor(out_vk ),
getTensor(out_vk),
getTensor(out_linearattn),
act_unsigned,
lora_scales,
......@@ -78,104 +75,64 @@ namespace nunchaku::ops {
getTensor(out_q),
getTensor(out_k),
getTensor(out_v),
attn_tokens
);
attn_tokens);
// Tensor::synchronizeDevice();
}
}
void attention_fp16(
torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale
) {
nunchaku::kernels::attention_fp16(
from_torch(q),
from_torch(k),
from_torch(v),
from_torch(o),
scale
);
}
float scale) {
nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
}
torch::Tensor gemv_awq(
torch::Tensor _in_feats,
torch::Tensor gemv_awq(torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int64_t m,
int64_t n,
int64_t k,
int64_t group_size)
{
Tensor result = ::gemv_awq(
from_torch(_in_feats.contiguous()),
int64_t group_size) {
Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()),
(int)m,
(int)n,
(int)k,
(int)group_size
);
(int)group_size);
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
return output;
}
}
torch::Tensor gemm_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros)
{
Tensor result = ::awq_gemm_forward_cuda(
from_torch(_in_feats.contiguous()),
torch::Tensor
gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) {
Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous())
);
from_torch(_zeros.contiguous()));
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
return output;
}
}
void test_rmsnorm_rope(
torch::Tensor input,
torch::Tensor output,
torch::Tensor norm_q,
torch::Tensor norm_k,
torch::Tensor rotary_emb)
{
void test_rmsnorm_rope(
torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
nunchaku::kernels::test_rmsnorm_rope(
from_torch(input),
from_torch(output),
from_torch(norm_q),
from_torch(norm_k),
from_torch(rotary_emb)
);
}
from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
}
void test_pack_qkv(
torch::Tensor input,
torch::Tensor out_q,
torch::Tensor out_k,
torch::Tensor out_v,
int numTokens)
{
void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
nunchaku::kernels::test_pack_qkv(
from_torch(input),
from_torch(out_q),
from_torch(out_k),
from_torch(out_v),
numTokens
);
}
from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
}
};
\ No newline at end of file
}; // namespace nunchaku::ops
......@@ -11,13 +11,14 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
.def(py::init<>())
.def("init", &QuantizedFluxModel::init,
.def("init",
&QuantizedFluxModel::init,
py::arg("use_fp4"),
py::arg("offload"),
py::arg("bf16"),
py::arg("deviceId")
)
.def("set_residual_callback", [](QuantizedFluxModel &self, pybind11::object call_back) {
py::arg("deviceId"))
.def("set_residual_callback",
[](QuantizedFluxModel &self, pybind11::object call_back) {
if (call_back.is_none()) {
self.set_residual_callback(pybind11::function());
} else {
......@@ -25,15 +26,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
}
})
.def("reset", &QuantizedFluxModel::reset)
.def("load", &QuantizedFluxModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedFluxModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("forward", &QuantizedFluxModel::forward,
.def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false)
.def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false)
.def("forward",
&QuantizedFluxModel::forward,
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
......@@ -42,9 +38,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("rotary_emb_single"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none(),
py::arg("skip_first_layer") = false
)
.def("forward_layer", &QuantizedFluxModel::forward_layer,
py::arg("skip_first_layer") = false)
.def("forward_layer",
&QuantizedFluxModel::forward_layer,
py::arg("idx"),
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
......@@ -52,8 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none()
)
py::arg("controlnet_single_block_samples") = py::none())
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
.def("startDebug", &QuantizedFluxModel::startDebug)
......@@ -61,32 +56,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale)
.def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
.def("isBF16", &QuantizedFluxModel::isBF16)
;
.def("isBF16", &QuantizedFluxModel::isBF16);
py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
.def(py::init<>())
.def("init", &QuantizedSanaModel::init,
.def("init",
&QuantizedSanaModel::init,
py::arg("config"),
py::arg("pag_layers"),
py::arg("use_fp4"),
py::arg("bf16"),
py::arg("deviceId")
)
py::arg("deviceId"))
.def("reset", &QuantizedSanaModel::reset)
.def("load", &QuantizedSanaModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedSanaModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false)
.def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false)
.def("forward", &QuantizedSanaModel::forward)
.def("forward_layer", &QuantizedSanaModel::forward_layer)
.def("startDebug", &QuantizedSanaModel::startDebug)
.def("stopDebug", &QuantizedSanaModel::stopDebug)
.def("getDebugResults", &QuantizedSanaModel::getDebugResults)
;
.def("getDebugResults", &QuantizedSanaModel::getDebugResults);
py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
.def(py::init<>())
.def("init", &QuantizedGEMM::init)
......@@ -96,8 +83,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("quantize", &QuantizedGEMM::quantize)
.def("startDebug", &QuantizedGEMM::startDebug)
.def("stopDebug", &QuantizedGEMM::stopDebug)
.def("getDebugResults", &QuantizedGEMM::getDebugResults)
;
.def("getDebugResults", &QuantizedGEMM::getDebugResults);
py::class_<Tensor>(m, "Tensor");
py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
.def(py::init<>())
......@@ -107,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("forward", &QuantizedGEMM88::forward)
.def("startDebug", &QuantizedGEMM88::startDebug)
.def("stopDebug", &QuantizedGEMM88::stopDebug)
.def("getDebugResults", &QuantizedGEMM88::getDebugResults)
;
.def("getDebugResults", &QuantizedGEMM88::getDebugResults);
m.def_submodule("ops")
.def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
......@@ -117,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("gemv_awq", nunchaku::ops::gemv_awq)
.def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
.def("test_pack_qkv", nunchaku::ops::test_pack_qkv)
;
.def("test_pack_qkv", nunchaku::ops::test_pack_qkv);
m.def_submodule("utils")
.def("set_log_level", [](const std::string &level) {
spdlog::set_level(spdlog::level::from_str(level));
})
.def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
.def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
.def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
.def("trim_memory", nunchaku::utils::trim_memory)
.def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode)
;
.def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode);
}