Unverified commit 57e50f8d, authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
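The hunks that follow are mechanical style fixes. As a quick orientation, a small self-contained Python sketch of two of the patterns the upgraded linter enforces in this commit, namely dropping the f prefix from strings without placeholders and removing bindings that are never read (the `results.append` line is an illustrative stand-in, not code from this repository):

results = []

# Before the linter upgrade, code like this appeared in the demo script:
#     title = f"SVDQuant SANA-1600M Demo"   # f-string with no placeholders
#     progress = gr.Progress(track_tqdm=True)  # binding never read again
# After reformatting, the f prefix and the unused binding are gone:
title = "SVDQuant SANA-1600M Demo"
results.append("step")  # call kept only for its side effect, result not bound

print(title, results)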
 import torch
 from diffusers import FluxPipeline
 from peft.tuners import lora
-from vars import LORA_PATHS, SVDQ_LORA_PATHS
 from nunchaku import NunchakuFluxTransformer2dModel
+from vars import LORA_PATHS, SVDQ_LORA_PATHS

 def hash_str_to_int(s: str) -> int:
......
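The snippet above ends at the signature of hash_str_to_int; its body is collapsed in this diff view. For context, a minimal sketch of what such a helper commonly does — this body is an assumption, not the function's actual implementation:

import hashlib


def hash_str_to_int(s: str) -> int:
    # Hypothetical body: derive a stable 32-bit integer from a string,
    # e.g. so a prompt or LoRA name can be reused as a reproducible seed.
    return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16) % (2**32)


print(hash_str_to_int("A peaceful world."))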
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo" alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/> style="height: 40px; width: auto; display: block; margin: auto;"/>
<a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo <a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo
......
@@ -2,7 +2,6 @@ import argparse
 import os

 import torch
 from utils import get_pipeline
......
@@ -4,7 +4,6 @@ import time
 import torch
 from torch import nn
 from tqdm import trange
 from utils import get_pipeline
......
@@ -8,13 +8,13 @@ from datetime import datetime
 import GPUtil
 import spaces
 import torch
+from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_pipeline
 from vars import EXAMPLES, MAX_SEED
-from nunchaku.models.safety_checker import SafetyChecker

 # import gradio last to avoid conflicts with other imports
-import gradio as gr
+import gradio as gr  # noqa: isort: skip

 def get_args() -> argparse.Namespace:
@@ -73,7 +73,7 @@ def generate(
     prompt = "A peaceful world."
     images, latency_strs = [], []
     for i, pipeline in enumerate(pipelines):
-        progress = gr.Progress(track_tqdm=True)
+        gr.Progress(track_tqdm=True)
         start_time = time.time()
         image = pipeline(
             prompt=prompt,
@@ -124,11 +124,11 @@ if len(gpus) > 0:
     device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
 else:
     device_info = "Running on CPU 🥶 This demo does not work on CPU."
-notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
+notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

 with gr.Blocks(
     css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
-    title=f"SVDQuant SANA-1600M Demo",
+    title="SVDQuant SANA-1600M Demo",
 ) as demo:

     def get_header_str():
......
@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel
 from diffusers.utils import load_image

 from nunchaku import NunchakuFluxTransformer2dModel
-from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe
 from nunchaku.utils import get_gpu_memory, get_precision

 base_model = "black-forest-labs/FLUX.1-dev"
@@ -29,11 +28,6 @@ if need_offload:
 else:
     pipeline = pipeline.to("cuda")

-# apply_cache_on_pipe(
-#     pipeline, residual_diff_threshold=0.1
-# )  # Uncomment this line to enable first-block cache to speedup generation
-
 prompt = "A anime style girl with messy beach waves."
 control_image_depth = load_image(
     "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
......
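The hunk above removes an unused caching import together with the commented-out call that documented how to enable it. For reference, a minimal sketch of how first-block caching could still be turned on for this pipeline, based only on the import path and threshold shown in the removed lines (treat the exact argument set as an assumption; the next file in this commit shows the same API with per-block thresholds):

from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe

# Sketch: `pipeline` is the Flux ControlNet pipeline constructed earlier in this
# script; caching skips redundant transformer work between denoising steps.
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.1)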
@@ -7,14 +7,10 @@ from nunchaku.utils import get_precision
 precision = get_precision()

-transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/svdq-{precision}-flux.1-dev"
-)
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
 pipeline = FluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    transformer=transformer,
-    torch_dtype=torch.bfloat16
+    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")

 apply_cache_on_pipe(
@@ -24,9 +20,6 @@ apply_cache_on_pipe(
     residual_diff_threshold_single=0.12,
 )

-image = pipeline(
-    ["A cat holding a sign that says hello world"],
-    num_inference_steps=50
-).images[0]
+image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
 image.save(f"flux.1-dev-cache-{precision}.png")
 from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel
+
+__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
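Listing the re-exported classes in __all__ tells import linters that the names are intentional public API rather than unused imports, and it controls what a wildcard import exposes; a minimal illustration (requires the nunchaku package to be installed):

import nunchaku

# __all__ is exactly the list declared in nunchaku/__init__.py above,
# and `from nunchaku import *` would bind only these names.
print(nunchaku.__all__)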
@@ -20,7 +20,8 @@ public:
         ModuleWrapper::init(deviceId);
         CUDADeviceContext ctx(this->deviceId);
-        net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<FluxModel>(
+            use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }

     bool isBF16() {
@@ -32,7 +33,7 @@ public:
         pybind11::gil_scoped_acquire gil;
         if (!callback || callback.is_none()) {
             residual_callback = pybind11::function();
-            if (net){
+            if (net) {
                 net->set_residual_callback(nullptr);
             }
             return;
@@ -52,8 +53,7 @@ public:
         }
     }

-    torch::Tensor forward(
-        torch::Tensor hidden_states,
+    torch::Tensor forward(torch::Tensor hidden_states,
                           torch::Tensor encoder_hidden_states,
                           torch::Tensor temb,
                           torch::Tensor rotary_emb_img,
@@ -61,8 +61,7 @@ public:
                           torch::Tensor rotary_emb_single,
                           std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
                           std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
-                          bool skip_first_layer = false)
-    {
+                          bool skip_first_layer = false) {
         checkModel();

         CUDADeviceContext ctx(deviceId);
@@ -83,9 +82,10 @@ public:
             from_torch(rotary_emb_context),
             from_torch(rotary_emb_single),
             controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
-            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
-            skip_first_layer
-        );
+            controlnet_single_block_samples.has_value()
+                ? from_torch(controlnet_single_block_samples.value().contiguous())
+                : Tensor{},
+            skip_first_layer);
         torch::Tensor output = to_torch(result);
         Tensor::synchronizeDevice();
@@ -93,16 +93,15 @@ public:
         return output;
     }

-    std::tuple<torch::Tensor, torch::Tensor> forward_layer(
-        int64_t idx,
+    std::tuple<torch::Tensor, torch::Tensor>
+    forward_layer(int64_t idx,
                   torch::Tensor hidden_states,
                   torch::Tensor encoder_hidden_states,
                   torch::Tensor temb,
                   torch::Tensor rotary_emb_img,
                   torch::Tensor rotary_emb_context,
                   std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
-                  std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
-    {
+                  std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
         CUDADeviceContext ctx(deviceId);

         spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
@@ -121,22 +120,21 @@ public:
             from_torch(rotary_emb_img),
             from_torch(rotary_emb_context),
             controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
-            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
-        );
+            controlnet_single_block_samples.has_value()
+                ? from_torch(controlnet_single_block_samples.value().contiguous())
+                : Tensor{});

         hidden_states = to_torch(hidden_states_);
         encoder_hidden_states = to_torch(encoder_hidden_states_);
         Tensor::synchronizeDevice();

-        return { hidden_states, encoder_hidden_states };
+        return {hidden_states, encoder_hidden_states};
     }

-    torch::Tensor forward_single_layer(
-        int64_t idx,
+    torch::Tensor forward_single_layer(int64_t idx,
                                        torch::Tensor hidden_states,
                                        torch::Tensor temb,
-                                       torch::Tensor rotary_emb_single)
-    {
+                                       torch::Tensor rotary_emb_single) {
         CUDADeviceContext ctx(deviceId);

         spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
@@ -146,10 +144,7 @@ public:
         rotary_emb_single = rotary_emb_single.contiguous();

         Tensor result = net->single_transformer_blocks.at(idx)->forward(
-            from_torch(hidden_states),
-            from_torch(temb),
-            from_torch(rotary_emb_single)
-        );
+            from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));

         hidden_states = to_torch(result);
         Tensor::synchronizeDevice();
@@ -159,19 +154,15 @@ public:
     // expose the norm1 forward method of the transformer blocks
     // this is used by TeaCache to get the norm1 output
-    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> norm_one_forward(
-        int64_t idx,
-        torch::Tensor hidden_states,
-        torch::Tensor temb
-    ) {
-        AdaLayerNormZero::Output result = net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
-        return {
-            to_torch(result.x),
+    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+    norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) {
+        AdaLayerNormZero::Output result =
+            net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
+        return {to_torch(result.x),
                 to_torch(result.gate_msa),
                 to_torch(result.shift_mlp),
                 to_torch(result.scale_mlp),
-            to_torch(result.gate_mlp)
-        };
+                to_torch(result.gate_mlp)};
     }

     // must be called after loading lora
@@ -214,5 +205,4 @@ public:
             throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
         }
     }
 };
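The class above is what the pybind11 module later in this commit exposes to Python. A minimal sketch of driving it from the Python side, using only methods visible in this file; the import path and the checkpoint filename are assumptions for illustration, not taken from this commit:

from nunchaku import _C  # assumed import path for the compiled extension

model = _C.QuantizedFluxModel()
model.init(use_fp4=False, offload=False, bf16=True, deviceId=0)
print(model.isBF16())              # reflects the bf16 flag passed to init
model.load("svdq-int4-flux.1-dev.safetensors")  # hypothetical checkpoint path
model.set_residual_callback(None)  # clears any Python-side residual hook
model.reset()                      # tears down the wrapped FluxModel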
@@ -16,7 +16,12 @@ public:
         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
         spdlog::debug("Stack={}", val);

-        net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<GEMM_W4A4>((int)in_features,
+                                          (int)out_features,
+                                          bias,
+                                          use_fp4,
+                                          bf16 ? Tensor::BF16 : Tensor::FP16,
+                                          Device::cuda((int)deviceId));
     }

     torch::Tensor forward(torch::Tensor x) {
@@ -95,10 +100,7 @@ public:
         x = x.contiguous();

-        auto qout = net->quantize(
-            from_torch(x),
-            fuse_glu
-        );
+        auto qout = net->quantize(from_torch(x), fuse_glu);

         Tensor act = qout.act.copy(Device::cpu());
         Tensor ascales = qout.ascales.copy(Device::cpu());
@@ -109,5 +111,4 @@ public:
         spdlog::debug("act = {}", dumpTensorINT4(act));
         spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
     }
 };
@@ -16,7 +16,8 @@ public:
         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
         spdlog::debug("Stack={}", val);

-        net = std::make_unique<GEMM_W8A8>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<GEMM_W8A8>(
+            (int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }

     torch::Tensor forward(torch::Tensor x) {
......
@@ -7,8 +7,7 @@
 namespace nunchaku::ops {

-    void gemm_w4a4(
-        std::optional<torch::Tensor> act,            // packed act [M, K / 2]
+    void gemm_w4a4(std::optional<torch::Tensor> act,            // packed act [M, K / 2]
                    std::optional<torch::Tensor> wgt,            // packed act [N, K / 2]
                    std::optional<torch::Tensor> out,            // linear [M, N]
                    std::optional<torch::Tensor> qout,           // packed act [M, N / 2]
@@ -26,7 +25,7 @@ namespace nunchaku::ops {
                    std::optional<torch::Tensor> bias,           // packed ws [N]
                    std::optional<torch::Tensor> smooth_factor,  // packed ws [N], for quantization of the next layer
                    std::optional<torch::Tensor> out_vk,         // linear [B, num_heads, head_dim + 1, head_dim]
-                   std::optional<torch::Tensor> out_linearattn,// linear [B, (M), N / 3]
+                   std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
                    bool act_unsigned,
                    std::vector<float> lora_scales,
                    bool fuse_silu,
@@ -36,8 +35,7 @@ namespace nunchaku::ops {
                    std::optional<torch::Tensor> out_q,          // packed attention [B, H, M, D]
                    std::optional<torch::Tensor> out_k,          // packed attention [B, H, M, D]
                    std::optional<torch::Tensor> out_v,          // packed attention [B, H, M, D]
-                   int attn_tokens
-    ) {
+                   int attn_tokens) {
        spdlog::trace("running gemm_w4a4: ");

        auto getTensor = [](std::optional<torch::Tensor> &t) {
@@ -49,25 +47,24 @@ namespace nunchaku::ops {
            }
            return ret;
        };
-       nunchaku::kernels::gemm_w4a4(
-           getTensor(act          ),
-           getTensor(wgt          ),
-           getTensor(out          ),
-           getTensor(qout         ),
-           getTensor(ascales      ),
-           getTensor(wscales      ),
-           getTensor(oscales      ),
-           getTensor(poolout      ),
-           getTensor(lora_act_in  ),
-           getTensor(lora_up      ),
-           getTensor(lora_down    ),
-           getTensor(lora_act_out ),
-           getTensor(norm_q       ),
-           getTensor(norm_k       ),
-           getTensor(rotary_emb   ),
-           getTensor(bias         ),
+       nunchaku::kernels::gemm_w4a4(getTensor(act),
+                                    getTensor(wgt),
+                                    getTensor(out),
+                                    getTensor(qout),
+                                    getTensor(ascales),
+                                    getTensor(wscales),
+                                    getTensor(oscales),
+                                    getTensor(poolout),
+                                    getTensor(lora_act_in),
+                                    getTensor(lora_up),
+                                    getTensor(lora_down),
+                                    getTensor(lora_act_out),
+                                    getTensor(norm_q),
+                                    getTensor(norm_k),
+                                    getTensor(rotary_emb),
+                                    getTensor(bias),
                                     getTensor(smooth_factor),
-           getTensor(out_vk       ),
+                                    getTensor(out_vk),
                                     getTensor(out_linearattn),
                                     act_unsigned,
                                     lora_scales,
@@ -78,104 +75,64 @@ namespace nunchaku::ops {
                                     getTensor(out_q),
                                     getTensor(out_k),
                                     getTensor(out_v),
-           attn_tokens
-        );
+                                    attn_tokens);
        // Tensor::synchronizeDevice();
    }
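The comments on the gemm_w4a4 arguments above document the packed layouts the W4A4 kernel expects: two 4-bit values share one byte along the K dimension, so packed buffers carry K / 2 bytes per row. As a sanity check, a small Python sketch of the resulting buffer shapes for illustrative sizes (the concrete M, K, N values are assumptions, not taken from this commit):

# Illustrative only: packed INT4 GEMM buffer shapes implied by the comments above.
M, K, N = 1024, 3072, 3072           # assumed token count and feature sizes

act_shape = (M, K // 2)              # packed activations: two 4-bit values per byte
wgt_shape = (N, K // 2)              # packed weights, same packing along K
out_shape = (M, N)                   # full-precision linear output
qout_shape = (M, N // 2)             # output re-quantized for the next layer

print(act_shape, wgt_shape, out_shape, qout_shape)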
-    void attention_fp16(
-        torch::Tensor q,   // packed [Batch, Head, TokensQ, HEAD_DIM]
+    void attention_fp16(torch::Tensor q,  // packed [Batch, Head, TokensQ, HEAD_DIM]
                         torch::Tensor k,  // packed [Batch, Head, TokensKV, HEAD_DIM]
                         torch::Tensor v,  // packed [Batch, Head, TokensKV, HEAD_DIM]
                         torch::Tensor o,  // linear [Batch, TokensQ, Head * HEAD_DIM]
-        float scale
-    ) {
-        nunchaku::kernels::attention_fp16(
-            from_torch(q),
-            from_torch(k),
-            from_torch(v),
-            from_torch(o),
-            scale
-        );
-    }
+                        float scale) {
+        nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
+    }

-    torch::Tensor gemv_awq(
-        torch::Tensor _in_feats,
+    torch::Tensor gemv_awq(torch::Tensor _in_feats,
                            torch::Tensor _kernel,
                            torch::Tensor _scaling_factors,
                            torch::Tensor _zeros,
                            int64_t m,
                            int64_t n,
                            int64_t k,
-        int64_t group_size)
-    {
-        Tensor result = ::gemv_awq(
-            from_torch(_in_feats.contiguous()),
+                           int64_t group_size) {
+        Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
                                    from_torch(_kernel.contiguous()),
                                    from_torch(_scaling_factors.contiguous()),
                                    from_torch(_zeros.contiguous()),
                                    (int)m,
                                    (int)n,
                                    (int)k,
-            (int)group_size
-        );
+                                   (int)group_size);
        torch::Tensor output = to_torch(result);
        // Tensor::synchronizeDevice();
        return output;
    }
-    torch::Tensor gemm_awq(
-        torch::Tensor _in_feats,
-        torch::Tensor _kernel,
-        torch::Tensor _scaling_factors,
-        torch::Tensor _zeros)
-    {
-        Tensor result = ::awq_gemm_forward_cuda(
-            from_torch(_in_feats.contiguous()),
+    torch::Tensor
+    gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) {
+        Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()),
                                                 from_torch(_kernel.contiguous()),
                                                 from_torch(_scaling_factors.contiguous()),
-            from_torch(_zeros.contiguous())
-        );
+                                                from_torch(_zeros.contiguous()));
        // TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
        torch::Tensor output = to_torch(result);
        // Tensor::synchronizeDevice();
        return output;
    }

    void test_rmsnorm_rope(
-        torch::Tensor input,
-        torch::Tensor output,
-        torch::Tensor norm_q,
-        torch::Tensor norm_k,
-        torch::Tensor rotary_emb)
-    {
+        torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
        nunchaku::kernels::test_rmsnorm_rope(
-            from_torch(input),
-            from_torch(output),
-            from_torch(norm_q),
-            from_torch(norm_k),
-            from_torch(rotary_emb)
-        );
-    }
+            from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
+    }

-    void test_pack_qkv(
-        torch::Tensor input,
-        torch::Tensor out_q,
-        torch::Tensor out_k,
-        torch::Tensor out_v,
-        int numTokens)
-    {
+    void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
        nunchaku::kernels::test_pack_qkv(
-            from_torch(input),
-            from_torch(out_q),
-            from_torch(out_k),
-            from_torch(out_v),
-            numTokens
-        );
-    }
+            from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
+    }

-};
+}; // namespace nunchaku::ops
\ No newline at end of file
@@ -11,13 +11,14 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
         .def(py::init<>())
-        .def("init", &QuantizedFluxModel::init,
+        .def("init",
+             &QuantizedFluxModel::init,
              py::arg("use_fp4"),
              py::arg("offload"),
              py::arg("bf16"),
-             py::arg("deviceId")
-        )
-        .def("set_residual_callback", [](QuantizedFluxModel &self, pybind11::object call_back) {
+             py::arg("deviceId"))
+        .def("set_residual_callback",
+             [](QuantizedFluxModel &self, pybind11::object call_back) {
                  if (call_back.is_none()) {
                      self.set_residual_callback(pybind11::function());
                  } else {
@@ -25,15 +26,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
                  }
              })
         .def("reset", &QuantizedFluxModel::reset)
-        .def("load", &QuantizedFluxModel::load,
-             py::arg("path"),
-             py::arg("partial") = false
-        )
-        .def("loadDict", &QuantizedFluxModel::loadDict,
-             py::arg("dict"),
-             py::arg("partial") = false
-        )
-        .def("forward", &QuantizedFluxModel::forward,
+        .def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false)
+        .def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false)
+        .def("forward",
+             &QuantizedFluxModel::forward,
              py::arg("hidden_states"),
             py::arg("encoder_hidden_states"),
             py::arg("temb"),
@@ -42,9 +38,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("rotary_emb_single"),
             py::arg("controlnet_block_samples") = py::none(),
             py::arg("controlnet_single_block_samples") = py::none(),
-            py::arg("skip_first_layer") = false
-        )
-        .def("forward_layer", &QuantizedFluxModel::forward_layer,
+            py::arg("skip_first_layer") = false)
+        .def("forward_layer",
+             &QuantizedFluxModel::forward_layer,
             py::arg("idx"),
             py::arg("hidden_states"),
             py::arg("encoder_hidden_states"),
@@ -52,8 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("rotary_emb_img"),
             py::arg("rotary_emb_context"),
             py::arg("controlnet_block_samples") = py::none(),
-            py::arg("controlnet_single_block_samples") = py::none()
-        )
+            py::arg("controlnet_single_block_samples") = py::none())
         .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
         .def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
         .def("startDebug", &QuantizedFluxModel::startDebug)
@@ -61,32 +56,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
         .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
         .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
-        .def("isBF16", &QuantizedFluxModel::isBF16)
-    ;
+        .def("isBF16", &QuantizedFluxModel::isBF16);
     py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
         .def(py::init<>())
-        .def("init", &QuantizedSanaModel::init,
+        .def("init",
+             &QuantizedSanaModel::init,
             py::arg("config"),
             py::arg("pag_layers"),
             py::arg("use_fp4"),
             py::arg("bf16"),
-             py::arg("deviceId")
-        )
+             py::arg("deviceId"))
         .def("reset", &QuantizedSanaModel::reset)
-        .def("load", &QuantizedSanaModel::load,
-             py::arg("path"),
-             py::arg("partial") = false
-        )
-        .def("loadDict", &QuantizedSanaModel::loadDict,
-             py::arg("dict"),
-             py::arg("partial") = false
-        )
+        .def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false)
+        .def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false)
         .def("forward", &QuantizedSanaModel::forward)
         .def("forward_layer", &QuantizedSanaModel::forward_layer)
         .def("startDebug", &QuantizedSanaModel::startDebug)
         .def("stopDebug", &QuantizedSanaModel::stopDebug)
-        .def("getDebugResults", &QuantizedSanaModel::getDebugResults)
-    ;
+        .def("getDebugResults", &QuantizedSanaModel::getDebugResults);
     py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
         .def(py::init<>())
         .def("init", &QuantizedGEMM::init)
@@ -96,8 +83,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("quantize", &QuantizedGEMM::quantize)
         .def("startDebug", &QuantizedGEMM::startDebug)
         .def("stopDebug", &QuantizedGEMM::stopDebug)
-        .def("getDebugResults", &QuantizedGEMM::getDebugResults)
-    ;
+        .def("getDebugResults", &QuantizedGEMM::getDebugResults);

     py::class_<Tensor>(m, "Tensor");

     py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
         .def(py::init<>())
@@ -107,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("forward", &QuantizedGEMM88::forward)
         .def("startDebug", &QuantizedGEMM88::startDebug)
         .def("stopDebug", &QuantizedGEMM88::stopDebug)
-        .def("getDebugResults", &QuantizedGEMM88::getDebugResults)
-    ;
+        .def("getDebugResults", &QuantizedGEMM88::getDebugResults);

     m.def_submodule("ops")
         .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
@@ -117,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("gemv_awq", nunchaku::ops::gemv_awq)
         .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
-        .def("test_pack_qkv", nunchaku::ops::test_pack_qkv)
-    ;
+        .def("test_pack_qkv", nunchaku::ops::test_pack_qkv);

     m.def_submodule("utils")
-        .def("set_log_level", [](const std::string &level) {
-            spdlog::set_level(spdlog::level::from_str(level));
-        })
+        .def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
         .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
         .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
         .def("trim_memory", nunchaku::utils::trim_memory)
-        .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode)
-    ;
+        .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode);
 }
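The utils submodule bound above gives Python-side control over logging and CUDA memory behaviour. A minimal usage sketch; the function names come from the bindings above, while the import path and the stack-limit value are assumptions for illustration:

from nunchaku import _C  # assumed import path for the compiled extension

_C.utils.set_log_level("debug")       # forwards to spdlog::set_level
_C.utils.set_cuda_stack_limit(8192)   # illustrative value, not from this commit
_C.utils.trim_memory()                # name suggests releasing cached GPU memory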