"docs/faq.mdx" did not exist on "b816ff86c923e0290f58f2275e831fc17c29ba37"
Commit e9ad0535 authored by muyangli's avatar muyangli
Browse files

[major] support SANA

parent 9eb2cee0
import torch
from diffusers import FluxPipeline
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
......@@ -8,4 +8,4 @@ pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
image.save("flux.1-schnell.png")
import torch
from diffusers import SanaPipeline
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
variant="bf16",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=4.5,
num_inference_steps=20,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m.png")
import torch
from diffusers import SanaPAGPipeline
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
pipe = SanaPAGPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
variant="bf16",
torch_dtype=torch.bfloat16,
pag_applied_layers="transformer_blocks.8",
).to("cuda")
pipe._set_pag_attn_processor = lambda *args, **kwargs: None
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.bfloat16)
image = pipe(
prompt="A cute 🐼 eating 🎋, ink drawing style",
height=1024,
width=1024,
guidance_scale=5.0,
pag_scale=2.0,
num_inference_steps=20,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m_pag.png")
__version__ = "0.0.2beta0"
__version__ = "0.0.2beta1"
......@@ -5,34 +5,15 @@
#include "Serialization.h"
#include "debug.h"
#include "Linear.h"
#include "module.h"
class QuantizedFluxModel { // : public torch::CustomClassHolder {
class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
public:
void init(bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedFluxModel");
net = std::make_unique<FluxModel>(bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
void reset() {
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
trimMemory();
Tensor::synchronizeDevice();
}
void load(std::string path, bool partial = false) {
checkModel();
spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
net->loadParams(*provider, partial);
Tensor::synchronizeDevice();
spdlog::info("Done.");
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
......@@ -123,44 +104,6 @@ public:
return hidden_states;
}
void disableMemoryAutoRelease() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}
void trimMemory() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
size_t bytesToKeep = 0;
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
void stopDebug() {
debugContext.reset();
}
auto getDebugResults() {
// c10::Dict<std::string, torch::Tensor> result;
std::map<std::string, torch::Tensor> result;
if (debugContext) {
for (auto &&[key, value] : debugContext->tensors) {
// result.insert(key, to_torch(value));
result[key] = to_torch(value);
}
}
return result;
}
// must be called after loading lora
// skip specific ranks in W4A4 layers
......@@ -178,7 +121,7 @@ public:
for (int i = 0; i < skipRanks / 16; i++) {
m->lora_scales[i] = 1.0f;
}
for (int i = skipRanks / 16; i < m->lora_scales.size(); i++) {
for (int i = skipRanks / 16; i < (int)m->lora_scales.size(); i++) {
m->lora_scales[i] = scale;
}
}
......@@ -189,15 +132,4 @@ public:
Attention::setForceFP16(net.get(), enable);
}
private:
void checkModel() {
if (!net) {
throw std::runtime_error("Model not initialized");
}
}
private:
std::unique_ptr<FluxModel> net;
std::unique_ptr<DebugContext> debugContext;
};
\ No newline at end of file
......@@ -4,11 +4,9 @@
#include "Serialization.h"
#include "Linear.h"
#include "debug.h"
#include "module.h"
#include "kernels/gemm_w4a4.h"
#include "kernels/awq/gemv_awq.h"
class QuantizedGEMM { // : public torch::CustomClassHolder {
class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM");
......@@ -21,21 +19,6 @@ public:
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
void reset() {
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
}
void load(std::string path) {
checkModel();
spdlog::info("Loading weights from {}", path);
std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
net->loadParams(*provider);
Tensor::synchronizeDevice();
}
torch::Tensor forward(torch::Tensor x) {
checkModel();
......@@ -43,9 +26,7 @@ public:
x = x.contiguous();
Tensor result = std::get<Tensor>(net->forward(
from_torch(x)
));
Tensor result = net->forward(from_torch(x));
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
......@@ -107,7 +88,7 @@ public:
return ss.str();
}
void quantize(torch::Tensor x) {
void quantize(torch::Tensor x, bool fuse_glu) {
checkModel();
spdlog::debug("QuantizedGEMM quantize");
......@@ -115,7 +96,8 @@ public:
x = x.contiguous();
auto qout = net->quantize(
from_torch(x)
from_torch(x),
fuse_glu
);
Tensor act = qout.act.copy(Device::cpu());
......@@ -128,120 +110,4 @@ public:
spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
}
void gemm(
c10::optional<torch::Tensor> act, // packed act [M, K / 2]
c10::optional<torch::Tensor> wgt, // packed act [N, K / 2]
c10::optional<torch::Tensor> out, // linear [M, N]
c10::optional<torch::Tensor> qout, // packed act [M, N / 2]
c10::optional<torch::Tensor> ascales, // packed as [K / 64, M]
c10::optional<torch::Tensor> wscales, // packed ws [K / 64, N]
c10::optional<torch::Tensor> oscales, // packed as [N / 64, M]
c10::optional<torch::Tensor> poolout, // linear [M / PoolSize, N]
c10::optional<torch::Tensor> lora_act_in, // packed lora_act [M, R]
c10::optional<torch::Tensor> lora_up, // packed lora_wgt [N, R]
c10::optional<torch::Tensor> lora_down, // packed lora_wgt [N, R]
c10::optional<torch::Tensor> lora_act_out, // packed lora_act [M, R]
c10::optional<torch::Tensor> norm_q, // linear [HEAD_DIM]
c10::optional<torch::Tensor> norm_k, // linear [HEAD_DIM]
c10::optional<torch::Tensor> rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
c10::optional<torch::Tensor> bias, // packed ws [N]
c10::optional<torch::Tensor> smooth_factor, // packed ws [N], for quantization of the next layer
bool act_unsigned,
std::vector<float> lora_scales
) {
std::cerr << "running gemm_w4a4: " << std::endl;
auto getTensor = [](c10::optional<torch::Tensor> &t) {
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
if (ret.valid()) {
std::cerr << " " << ret.shape.str() << std::endl;
} else {
std::cerr << " <invalid>" << std::endl;
}
return ret;
};
gemm_w4a4(
getTensor(act ),
getTensor(wgt ),
getTensor(out ),
getTensor(qout ),
getTensor(ascales ),
getTensor(wscales ),
getTensor(oscales ),
getTensor(poolout ),
getTensor(lora_act_in ),
getTensor(lora_up ),
getTensor(lora_down ),
getTensor(lora_act_out ),
getTensor(norm_q ),
getTensor(norm_k ),
getTensor(rotary_emb ),
getTensor(bias ),
getTensor(smooth_factor),
act_unsigned,
lora_scales
);
Tensor::synchronizeDevice();
}
torch::Tensor gemv_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int64_t m,
int64_t n,
int64_t k,
int64_t group_size)
{
Tensor result = ::gemv_awq(
from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()),
(int)m,
(int)n,
(int)k,
(int)group_size
);
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
void stopDebug() {
debugContext.reset();
}
auto getDebugResults() {
// c10::Dict<std::string, torch::Tensor> result;
std::map<std::string, torch::Tensor> result;
if (debugContext) {
for (auto &&[key, value] : debugContext->tensors) {
// result.insert(key, to_torch(value));
result[key] = to_torch(value);
}
}
return result;
}
private:
void checkModel() {
if (!net) {
throw std::runtime_error("Model not initialized");
}
}
private:
std::unique_ptr<GEMM_W4A4> net;
std::unique_ptr<DebugContext> debugContext;
};
\ No newline at end of file
#pragma once
#include "interop/torch.h"
#include "Serialization.h"
#include "Linear.h"
#include "debug.h"
#include "module.h"
class QuantizedGEMM88 : public ModuleWrapper<GEMM_W8A8> {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM88");
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W8A8>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(torch::Tensor x) {
checkModel();
std::cerr << "QuantizedGEMM88 forward" << std::endl;
x = x.contiguous();
Tensor result = net->forward(from_torch(x));
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
};
\ No newline at end of file
#pragma once
#include "interop/torch.h"
#include "Serialization.h"
#include "Module.h"
#include "debug.h"
#include "utils.h"
template<typename M>
class ModuleWrapper {
public:
void reset() {
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
nunchaku::utils::trim_memory();
Tensor::synchronizeDevice();
}
void load(std::string path, bool partial = false) {
checkModel();
spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
net->loadParams(*provider, partial);
Tensor::synchronizeDevice();
spdlog::info("Done.");
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
void stopDebug() {
debugContext.reset();
}
auto getDebugResults() {
std::map<std::string, torch::Tensor> result;
if (debugContext) {
for (auto &&[key, value] : debugContext->tensors) {
result[key] = to_torch(value);
}
}
return result;
}
protected:
void checkModel() {
if (!net) {
throw std::runtime_error("Model not initialized");
}
}
protected:
std::unique_ptr<M> net;
std::unique_ptr<DebugContext> debugContext;
};
\ No newline at end of file
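ModuleWrapper above factors the reset/load/debug plumbing out of QuantizedFluxModel and QuantizedGEMM. A minimal sketch, assuming the extension is importable as nunchaku._C (as in the Python sources later in this commit), of how these inherited methods are typically driven from Python; the checkpoint path and the forward call are placeholders:

from nunchaku._C import QuantizedFluxModel

m = QuantizedFluxModel()
m.init(True, 0)                           # bf16=True, deviceId=0
m.load("transformer_blocks.safetensors")  # placeholder path

m.startDebug()                            # start capturing intermediate tensors
# ... call m.forward(...) with the model's inputs here ...
captured = m.getDebugResults()            # dict of tensor name -> torch.Tensor
m.stopDebug()

m.reset()                                 # drop the weights and trim the CUDA memory pool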
#pragma once
#include "interop/torch.h"
#include "kernels/zgemm/zgemm.h"
#include "kernels/awq/gemv_awq.h"
namespace nunchaku::ops {
void gemm_w4a4(
std::optional<torch::Tensor> act, // packed act [M, K / 2]
std::optional<torch::Tensor> wgt, // packed act [N, K / 2]
std::optional<torch::Tensor> out, // linear [M, N]
std::optional<torch::Tensor> qout, // packed act [M, N / 2]
std::optional<torch::Tensor> ascales, // packed as [K / 64, M]
std::optional<torch::Tensor> wscales, // packed ws [K / 64, N]
std::optional<torch::Tensor> oscales, // packed as [N / 64, M]
std::optional<torch::Tensor> poolout, // linear [M / PoolSize, N]
std::optional<torch::Tensor> lora_act_in, // packed lora_act [M, R]
std::optional<torch::Tensor> lora_up, // packed lora_wgt [N, R]
std::optional<torch::Tensor> lora_down, // packed lora_wgt [N, R]
std::optional<torch::Tensor> lora_act_out, // packed lora_act [M, R]
std::optional<torch::Tensor> norm_q, // linear [HEAD_DIM]
std::optional<torch::Tensor> norm_k, // linear [HEAD_DIM]
std::optional<torch::Tensor> rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
std::optional<torch::Tensor> bias, // packed ws [N]
std::optional<torch::Tensor> smooth_factor, // packed ws [N], for quantization of the next layer
std::optional<torch::Tensor> out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
std::optional<torch::Tensor> out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales,
bool fuse_silu
) {
spdlog::trace("running gemm_w4a4: ");
auto getTensor = [](std::optional<torch::Tensor> &t) {
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
if (ret.valid()) {
spdlog::trace(" {}", ret.shape.str());
} else {
spdlog::trace(" <invalid>");
}
return ret;
};
nunchaku::kernels::gemm_w4a4(
getTensor(act ),
getTensor(wgt ),
getTensor(out ),
getTensor(qout ),
getTensor(ascales ),
getTensor(wscales ),
getTensor(oscales ),
getTensor(poolout ),
getTensor(lora_act_in ),
getTensor(lora_up ),
getTensor(lora_down ),
getTensor(lora_act_out ),
getTensor(norm_q ),
getTensor(norm_k ),
getTensor(rotary_emb ),
getTensor(bias ),
getTensor(smooth_factor),
getTensor(out_vk ),
getTensor(out_linearattn),
act_unsigned,
lora_scales,
fuse_silu
);
Tensor::synchronizeDevice();
}
torch::Tensor gemv_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int64_t m,
int64_t n,
int64_t k,
int64_t group_size)
{
Tensor result = ::gemv_awq(
from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()),
(int)m,
(int)n,
(int)k,
(int)group_size
);
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
};
\ No newline at end of file
#include "gemm.h"
#include "gemm88.h"
#include "flux.h"
#include "sana.h"
#include "ops.h"
#include "utils.h"
#include <pybind11/pybind11.h>
// TORCH_LIBRARY(diffuxer, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
// .def(torch::init<>())
.def(py::init<>())
.def("init", &QuantizedFluxModel::init,
py::arg("bf16"),
......@@ -20,26 +22,63 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("forward", &QuantizedFluxModel::forward)
.def("forward_layer", &QuantizedFluxModel::forward_layer)
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("disableMemoryAutoRelease", &QuantizedFluxModel::disableMemoryAutoRelease)
.def("trimMemory", &QuantizedFluxModel::trimMemory)
.def("startDebug", &QuantizedFluxModel::startDebug)
.def("stopDebug", &QuantizedFluxModel::stopDebug)
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale)
.def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
;
py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
.def(py::init<>())
.def("init", &QuantizedSanaModel::init,
py::arg("config"),
py::arg("pag_layers"),
py::arg("bf16"),
py::arg("deviceId")
)
.def("reset", &QuantizedSanaModel::reset)
.def("load", &QuantizedSanaModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("forward", &QuantizedSanaModel::forward)
.def("forward_layer", &QuantizedSanaModel::forward_layer)
.def("startDebug", &QuantizedSanaModel::startDebug)
.def("stopDebug", &QuantizedSanaModel::stopDebug)
.def("getDebugResults", &QuantizedSanaModel::getDebugResults)
;
py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
// .def(torch::init<>())
.def(py::init<>())
.def("init", &QuantizedGEMM::init)
.def("reset", &QuantizedGEMM::reset)
.def("load", &QuantizedGEMM::load)
.def("forward", &QuantizedGEMM::forward)
.def("quantize", &QuantizedGEMM::quantize)
.def("gemm", &QuantizedGEMM::gemm)
.def("gemv_awq", &QuantizedGEMM::gemv_awq)
.def("startDebug", &QuantizedGEMM::startDebug)
.def("stopDebug", &QuantizedGEMM::stopDebug)
.def("getDebugResults", &QuantizedGEMM::getDebugResults)
;
py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
.def(py::init<>())
.def("init", &QuantizedGEMM88::init)
.def("reset", &QuantizedGEMM88::reset)
.def("load", &QuantizedGEMM88::load)
.def("forward", &QuantizedGEMM88::forward)
.def("startDebug", &QuantizedGEMM88::startDebug)
.def("stopDebug", &QuantizedGEMM88::stopDebug)
.def("getDebugResults", &QuantizedGEMM88::getDebugResults)
;
m.def_submodule("ops")
.def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
.def("gemv_awq", nunchaku::ops::gemv_awq)
;
m.def_submodule("utils")
.def("set_log_level", [](const std::string &level) {
spdlog::set_level(spdlog::level::from_str(level));
})
.def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
.def("trim_memory", nunchaku::utils::trim_memory)
;
}
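Besides the model classes, the extension now exposes ops and utils submodules. A minimal sketch of the utils side, mirroring how transformer_flux.py in this commit imports it ("debug" is just one example of a level string accepted by spdlog):

from nunchaku._C import utils as cutils

cutils.set_log_level("debug")          # forwarded to spdlog::set_level
cutils.disable_memory_auto_release()   # raise the CUDA mempool release threshold to UINT64_MAX
# ... run inference ...
cutils.trim_memory()                   # cudaMemPoolTrimTo(pool, 0): return cached pool memory to the driver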
#pragma once
#include "interop/torch.h"
#include "SanaModel.h"
#include "Serialization.h"
#include "debug.h"
#include "module.h"
class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
public:
void init(pybind11::dict config, std::vector<int> pag_layers, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedSanaModel");
SanaConfig cfg{
.num_layers = config["num_layers"].cast<int>(),
.num_attention_heads = config["num_attention_heads"].cast<int>(),
.attention_head_dim = config["attention_head_dim"].cast<int>(),
.num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
.expand_ratio = config["mlp_ratio"].cast<double>(),
.pag_layers = pag_layers,
};
net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg)
{
checkModel();
spdlog::debug("QuantizedSanaModel forward");
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
timestep = timestep.contiguous();
cu_seqlens_img = cu_seqlens_img.contiguous();
cu_seqlens_txt = cu_seqlens_txt.contiguous();
Tensor result = net->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(timestep),
from_torch(cu_seqlens_img),
from_torch(cu_seqlens_txt),
H, W,
pag, cfg
);
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
return output;
}
torch::Tensor forward_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg)
{
checkModel();
spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
timestep = timestep.contiguous();
cu_seqlens_img = cu_seqlens_img.contiguous();
cu_seqlens_txt = cu_seqlens_txt.contiguous();
Tensor result = net->transformer_blocks.at(idx)->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(timestep),
from_torch(cu_seqlens_img),
from_torch(cu_seqlens_txt),
H, W,
pag, cfg
);
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
return output;
}
};
\ No newline at end of file
#pragma once
#include "common.h"
#include "Tensor.h"
namespace nunchaku::utils {
void disable_memory_auto_release() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}
void trim_memory() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
size_t bytesToKeep = 0;
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
};
\ No newline at end of file
......@@ -2,15 +2,14 @@ import os
import diffusers
import torch
from diffusers import __version__, FluxTransformer2DModel
from diffusers import FluxTransformer2DModel
from diffusers.configuration_utils import register_to_config
from huggingface_hub import hf_hub_download, utils, constants
from huggingface_hub import hf_hub_download, utils
from packaging.version import Version
from safetensors.torch import load_file
from torch import nn
from .._C import QuantizedFluxModel
from .utils import NunchakuModelLoaderMixin
from .._C import QuantizedFluxModel, utils as cutils
SVD_RANK = 32
......@@ -109,13 +108,13 @@ def load_quantized_module(path: str, device: str | torch.device = "cuda") -> Qua
assert device.type == "cuda"
m = QuantizedFluxModel()
m.disableMemoryAutoRelease()
cutils.disable_memory_auto_release()
m.init(True, 0 if device.index is None else device.index)
m.load(path)
return m
class NunchakuFluxTransformer2dModel(FluxTransformer2DModel):
class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
@register_to_config
def __init__(
self,
......@@ -146,66 +145,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel):
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
unquantized_part_path = os.path.join(dirname, "unquantized_layers.safetensors")
transformer_block_path = os.path.join(dirname, "transformer_blocks.safetensors")
else:
download_kwargs = {
"subfolder": subfolder,
"repo_type": "model",
"revision": kwargs.get("revision", None),
"cache_dir": kwargs.get("cache_dir", None),
"local_dir": kwargs.get("local_dir", None),
"user_agent": kwargs.get("user_agent", None),
"force_download": kwargs.get("force_download", False),
"proxies": kwargs.get("proxies", None),
"etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
"token": kwargs.get("token", None),
"local_files_only": kwargs.get("local_files_only", None),
"headers": kwargs.get("headers", None),
"endpoint": kwargs.get("endpoint", None),
"resume_download": kwargs.get("resume_download", None),
"force_filename": kwargs.get("force_filename", None),
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
unquantized_part_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="unquantized_layers.safetensors", **download_kwargs
)
transformer_block_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
)
config, _, _ = cls.load_config(
pretrained_model_name_or_path,
subfolder=subfolder,
cache_dir=kwargs.get("cache_dir", None),
return_unused_kwargs=True,
return_commit_hash=True,
force_download=kwargs.get("force_download", False),
proxies=kwargs.get("proxies", None),
local_files_only=kwargs.get("local_files_only", None),
token=kwargs.get("token", None),
revision=kwargs.get("revision", None),
user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
**kwargs,
)
device = kwargs.get("device", "cuda")
transformer: NunchakuFluxTransformer2dModel = cls.from_config(config).to(
kwargs.get("torch_dtype", torch.bfloat16)
)
state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(state_dict, strict=False)
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
m = load_quantized_module(transformer_block_path, device=device)
transformer.inject_quantized_module(m, device)
return transformer
def update_lora_params(self, path: str):
......
import os
from typing import Optional
import torch
import torch.nn.functional as F
from diffusers import SanaTransformer2DModel
from diffusers.configuration_utils import register_to_config
from huggingface_hub import utils
from torch import nn
from .utils import NunchakuModelLoaderMixin
from .._C import QuantizedSanaModel, utils as cutils
SVD_RANK = 32
class NunchakuSanaTransformerBlocks(nn.Module):
def __init__(self, m: QuantizedSanaModel, dtype: torch.dtype, device: str | torch.device):
super(NunchakuSanaTransformerBlocks, self).__init__()
self.m = m
self.dtype = dtype
self.device = device
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
):
batch_size = hidden_states.shape[0]
img_tokens = hidden_states.shape[1]
txt_tokens = encoder_hidden_states.shape[1]
original_dtype = hidden_states.dtype
original_device = hidden_states.device
assert encoder_attention_mask is not None
assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)
mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000]
cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
cu_seqlens_img = torch.arange(
0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
)
if height is None and width is None:
height = width = int(img_tokens**0.5)
elif height is None:
height = img_tokens // width
elif width is None:
width = img_tokens // height
assert height * width == img_tokens
return (
self.m.forward(
hidden_states.to(self.dtype).to(self.device),
nunchaku_encoder_hidden_states.to(self.dtype).to(self.device),
timestep.to(self.dtype).to(self.device),
cu_seqlens_img.to(self.device),
cu_seqlens_txt.to(self.device),
height,
width,
batch_size % 3 == 0, # pag is set when loading the model, FIXME: pag_scale == 0
True, # TODO: find a way to detect if we are doing CFG
)
.to(original_dtype)
.to(original_device)
)
class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
@register_to_config
def __init__(
self,
in_channels: int = 32,
out_channels: Optional[int] = 32,
num_attention_heads: int = 70,
attention_head_dim: int = 32,
num_layers: int = 20,
num_cross_attention_heads: Optional[int] = 20,
cross_attention_head_dim: Optional[int] = 112,
cross_attention_dim: Optional[int] = 2240,
caption_channels: int = 2304,
mlp_ratio: float = 2.5,
dropout: float = 0.0,
attention_bias: bool = False,
sample_size: int = 32,
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
) -> None:
# set num_layers to 0 to avoid creating transformer blocks
self.original_num_layers = num_layers
super(NunchakuSanaTransformer2DModel, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_layers=0,
num_cross_attention_heads=num_cross_attention_heads,
cross_attention_head_dim=cross_attention_head_dim,
cross_attention_dim=cross_attention_dim,
caption_channels=caption_channels,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_bias=attention_bias,
sample_size=sample_size,
patch_size=patch_size,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
interpolation_scale=interpolation_scale,
)
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
device = kwargs.get("device", "cuda")
pag_layers = kwargs.get("pag_layers", [])
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
transformer.config["num_layers"] = transformer.original_num_layers
m = load_quantized_module(transformer, transformer_block_path, device=device, pag_layers=pag_layers)
transformer.inject_quantized_module(m, device)
return transformer
def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)])
return self
def load_quantized_module(
net: SanaTransformer2DModel,
path: str,
device: str | torch.device = "cuda",
pag_layers: int | list[int] | None = None,
) -> QuantizedSanaModel:
if pag_layers is None:
pag_layers = []
elif isinstance(pag_layers, int):
pag_layers = [pag_layers]
device = torch.device(device)
assert device.type == "cuda"
m = QuantizedSanaModel()
cutils.disable_memory_auto_release()
m.init(net.config, pag_layers, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
m.load(path)
return m
def inject_quantized_module(
net: SanaTransformer2DModel, m: QuantizedSanaModel, device: torch.device
) -> SanaTransformer2DModel:
net.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, net.dtype, device)])
return net
import os
import torch
from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
class NunchakuModelLoaderMixin:
@classmethod
def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
unquantized_part_path = os.path.join(dirname, "unquantized_layers.safetensors")
transformer_block_path = os.path.join(dirname, "transformer_blocks.safetensors")
else:
download_kwargs = {
"subfolder": subfolder,
"repo_type": "model",
"revision": kwargs.get("revision", None),
"cache_dir": kwargs.get("cache_dir", None),
"local_dir": kwargs.get("local_dir", None),
"user_agent": kwargs.get("user_agent", None),
"force_download": kwargs.get("force_download", False),
"proxies": kwargs.get("proxies", None),
"etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
"token": kwargs.get("token", None),
"local_files_only": kwargs.get("local_files_only", None),
"headers": kwargs.get("headers", None),
"endpoint": kwargs.get("endpoint", None),
"resume_download": kwargs.get("resume_download", None),
"force_filename": kwargs.get("force_filename", None),
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
unquantized_part_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="unquantized_layers.safetensors", **download_kwargs
)
transformer_block_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
)
config, _, _ = cls.load_config(
pretrained_model_name_or_path,
subfolder=subfolder,
cache_dir=kwargs.get("cache_dir", None),
return_unused_kwargs=True,
return_commit_hash=True,
force_download=kwargs.get("force_download", False),
proxies=kwargs.get("proxies", None),
local_files_only=kwargs.get("local_files_only", None),
token=kwargs.get("token", None),
revision=kwargs.get("revision", None),
user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
**kwargs,
)
transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(state_dict, strict=False)
return transformer, transformer_block_path
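Both transformer classes consume this helper the same way. The sketch below simply restates what NunchakuFluxTransformer2dModel.from_pretrained does internally (model id as in the examples above), so it is illustrative rather than a recommended entry point:

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel, load_quantized_module

# _build_model builds the bf16 skeleton from the config, loads unquantized_layers.safetensors
# into it, and returns the path to transformer_blocks.safetensors for the C++ loader.
transformer, block_path = NunchakuFluxTransformer2dModel._build_model("mit-han-lab/svdq-int4-flux.1-dev")
m = load_quantized_module(block_path, device="cuda")   # QuantizedFluxModel with weights loaded
transformer.inject_quantized_module(m, "cuda")         # replace the transformer blocks in place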
......@@ -89,6 +89,7 @@ if __name__ == "__main__":
"src/layernorm.cpp",
"src/Linear.cpp",
*ncond("src/FluxModel.cpp"),
*ncond("src/SanaModel.cpp"),
"src/Serialization.cpp",
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
......@@ -101,7 +102,11 @@ if __name__ == "__main__":
"src/kernels/activation_kernels.cu",
"src/kernels/layernorm_kernels.cu",
"src/kernels/misc_kernels.cu",
"src/kernels/gemm_w4a4.cu",
"src/kernels/zgemm/gemm_w4a4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16.cu",
"src/kernels/zgemm/gemm_w8a8.cu",
"src/kernels/dwconv.cu",
"src/kernels/gemm_batched.cu",
"src/kernels/gemm_f16.cu",
"src/kernels/awq/gemv_awq.cu",
......
......@@ -9,12 +9,13 @@
#include <iostream>
using spdlog::fmt_lib::format;
using namespace nunchaku;
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
Tensor ff_output = std::get<Tensor>(fc2.forward_quant(
std::get<GEMM_W4A4::QuantizedActivation>(fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)))
Tensor ff_output = fc2.forward_quant(
std::get<GEMM_W4A4::QuantizedActivation>(fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2))
);
return ff_output;
}
......@@ -26,7 +27,8 @@ Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
return std::get<Tensor>(fc.forward(x));
return fc.forward(x);
// return std::get<Tensor>(fc.forward(x));
}
// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
......@@ -49,7 +51,7 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
debug("emb_input", emb);
emb = linear.forward(Silu::forward(emb));
debug("emb_linear", emb);
auto &&[shift_msa, scale_msa, gate_msa] = split_mod<3>(emb);
auto &&[shift_msa, scale_msa, gate_msa] = kernels::split_mod<3>(emb);
debug("scale_msa", scale_msa);
debug("shift_msa", shift_msa);
......@@ -57,7 +59,7 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add(norm_x, scale_msa, shift_msa);
return Output{norm_x, gate_msa};
}
......@@ -80,24 +82,24 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
debug("emb_linear", emb);
if (pre_only) {
auto &&[shift_msa, scale_msa] = split_mod<2>(emb);
auto &&[shift_msa, scale_msa] = kernels::split_mod<2>(emb);
debug("shift_msa", shift_msa);
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add(norm_x, scale_msa, shift_msa);
debug("norm_x_scaled", norm_x);
return Output{norm_x};
} else {
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = split_mod<6>(emb);
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
debug("shift_msa", shift_msa);
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add(norm_x, scale_msa, shift_msa);
debug("norm_x_scaled", norm_x);
return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
......@@ -149,7 +151,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
}
}
blockmask = topk(pool_score, pool_tokens * (1 - sparsityRatio));
blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));
if (cu_seqlens_cpu.valid()) {
if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
......@@ -173,7 +175,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
if (cast_fp16) {
Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
cast(qkv, tmp);
kernels::cast(qkv, tmp);
qkv = tmp;
}
......@@ -206,7 +208,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
if (cast_fp16) {
Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
cast(raw_attn_output, tmp);
kernels::cast(raw_attn_output, tmp);
raw_attn_output = tmp;
}
......@@ -315,10 +317,10 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
debug("ff_output", ff_output);
hidden_states = add(attn_output, ff_output);
hidden_states = kernels::add(attn_output, ff_output);
debug("attn_ff_output", hidden_states);
mul_add(hidden_states, gate, residual);
kernels::mul_add(hidden_states, gate, residual);
nvtxRangePop();
......@@ -501,7 +503,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("img.attn_output", attn_output);
#if 1
mul_add(attn_output, gate_msa, hidden_states);
kernels::mul_add(attn_output, gate_msa, hidden_states);
hidden_states = std::move(attn_output);
nvtxRangePop();
......@@ -512,7 +514,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor norm_hidden_states = norm2.forward(hidden_states);
debug("scale_mlp", scale_mlp);
debug("shift_mlp", shift_mlp);
mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
......@@ -525,7 +527,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("img.ff_output", ff_output);
debug("gate_mlp", gate_mlp);
mul_add(ff_output, gate_mlp, hidden_states);
kernels::mul_add(ff_output, gate_mlp, hidden_states);
hidden_states = std::move(ff_output);
nvtxRangePop();
......@@ -566,7 +568,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("context.attn_output", attn_output);
#if 1
mul_add(attn_output, gate_msa, encoder_hidden_states);
kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
encoder_hidden_states = std::move(attn_output);
nvtxRangePop();
......@@ -577,7 +579,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
debug("c_scale_mlp", scale_mlp);
debug("c_shift_mlp", shift_mlp);
mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
......@@ -592,7 +594,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("context.ff_output", ff_output);
debug("c_gate_mlp", gate_mlp);
mul_add(ff_output, gate_mlp, encoder_hidden_states);
kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
encoder_hidden_states = std::move(ff_output);
nvtxRangePop();
......
#include "Linear.h"
#include "kernels/gemm_w4a4.h"
#include "kernels/zgemm/zgemm.h"
#include "kernels/gemm_f16.h"
#include "kernels/misc_kernels.h"
#include "kernels/awq/gemv_awq.h"
#include "kernels/dwconv.h"
#include <nvtx3/nvToolsExt.h>
using namespace nunchaku;
GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features)
{
this->weight = Tensor::allocate({out_features, in_features}, dtype, device);
this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
registerParams
(weight, "weight")
(bias, "bias")
;
}
Tensor GEMM_F16::forward(Tensor x) {
Tensor out = gemm_f16(x, this->weight, {}, this->bias, 1.0f);
return out;
}
GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f)
......@@ -51,19 +73,19 @@ Tensor GEMV_AWQ::forward(Tensor x) {
if (bias.valid()) {
// TODO: batch
assert(out.numel() == bias.numel());
out = add(out, bias.view(out.shape.dataExtent));
out = kernels::add(out, bias.view(out.shape.dataExtent));
}
debug("out_before_lora", out);
if (this->lora_rank > 0) {
Tensor lora_act = gemm_f16(x, this->lora_down, {}, 1.0f, 0.0f);
Tensor lora_act = gemm_f16(x, this->lora_down, {}, {}, 1.0f);
debug("lora_act", lora_act);
Tensor lora_out = gemm_f16(lora_act, this->lora_up, {}, this->lora_scale, 0.0f);
Tensor lora_out = gemm_f16(lora_act, this->lora_up, {}, {}, this->lora_scale);
debug("lora_out", lora_out);
out = add(out, lora_out);
out = kernels::add(out, lora_out);
}
debug("out", out);
......@@ -75,19 +97,21 @@ Tensor GEMV_AWQ::forward(Tensor x) {
#define NO_LORA_FUSION 0
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), lora_rank(0), dtype(dtype)
in_features(in_features), out_features(out_features),
in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
lora_rank(0), dtype(dtype)
{
this->qweight = Tensor::allocate({out_features, in_features / 2}, Tensor::INT8, device, true);
this->wscales = Tensor::allocate({in_features / 64, out_features}, dtype, device, true);
this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
this->lora_down = Tensor::allocate({in_features, lora_rank}, dtype, device, true);
this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
// TODO: smooth factor in FC1+FC2 fusion
// TODO: smooth factor in non-Lora fusion
this->smooth = Tensor::allocate({in_features}, dtype, device, true);
this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);
registerParams
(qweight, "qweight")
......@@ -118,12 +142,20 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
}
}
Tensor GEMM_W4A4::forward(Tensor x) {
return std::get<Tensor>(this->forward(x, FuseOptions::EMPTY, nullptr));
}
Tensor GEMM_W4A4::forward_silu(Tensor x) {
return std::get<Tensor>(this->forward(x, FuseOptions::SILU, nullptr));
}
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
return forward_quant(quantize(x), fuse, nextGEMM);
return forward_quant(quantize(x, false), fuse, nextGEMM);
}
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
QuantizedActivation qact = quantize(x);
QuantizedActivation qact = quantize(x, false);
#if !NO_LORA_FUSION
......@@ -135,13 +167,13 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
debug("gemm.nolora.out", out);
#endif
gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales);
kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false);
debug("gemm.out", out);
#else
const int M = (int)qact.act.numel() / qact.act.shape[-1];
gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {});
kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales);
nvtxRangePushA("LoraUp");
......@@ -175,17 +207,18 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
const int M = (int)qact.act.numel() / qact.act.shape[-1];
if (fuse == FuseOptions::EMPTY) {
auto shape = TensorShape(qact.act.shape.dataExtent);
if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
// auto shape = TensorShape(qact.act.shape.dataExtent);
// shape[-1] = out_features;
auto shape = TensorShape(qact.actShape.dataExtent);
shape[-1] = out_features;
out = Tensor::allocate(shape, dtype, qweight.device());
} else {
auto shape = TensorShape(qact.act.shape.dataExtent);
shape[-1] = out_features / 2;
qout.act = Tensor::allocate(shape, Tensor::INT8, qweight.device());
qout.ascales = Tensor::allocate({out_features / 64, M}, dtype, qweight.device());
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qout.is_unsigned = true;
qout.actShape = qact.actShape;
next_lora = nextGEMM->lora_down;
next_smooth = nextGEMM->smooth;
......@@ -208,9 +241,9 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
}
#endif
gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales);
kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU);
if (fuse == FuseOptions::EMPTY) {
if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
debug("gemm.out", out);
} else {
debug("gemm.qout", qout.act);
......@@ -226,7 +259,7 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
out = Tensor::allocate(shape, Tensor::FP16, qweight.device());
}
gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {});
kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {}, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales);
nvtxRangePushA("LoraUp");
......@@ -281,23 +314,29 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
return qout;
}
GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x) {
const int M = x.numel() / x.shape[-1];
Tensor GEMM_W4A4::forward_quant(QuantizedActivation qact) {
return std::get<Tensor>(this->forward_quant(qact, FuseOptions::EMPTY, nullptr));
}
GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
const int actualM = x.numel() / x.shape[-1];
const int M = ceilDiv(actualM, 256) * 256;
auto shape = TensorShape(x.shape.dataExtent);
shape[-1] = in_features / 2;
// auto shape = TensorShape(x.shape.dataExtent);
// shape[-1] = in_features / 2;
QuantizedActivation qact;
qact.act = Tensor::allocate(shape, Tensor::INT8, qweight.device());
qact.ascales = Tensor::allocate({in_features / 64, M}, dtype, qweight.device());
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qact.is_unsigned = false;
qact.actShape = x.shape.dataExtent;
#if !NO_LORA_FUSION
debug("quantize.x", x);
debug("quantize.smooth", this->smooth);
quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth);
kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu);
debug("quantize.qact", qact.act);
debug("quantize.ascales", qact.ascales);
......@@ -325,9 +364,68 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x) {
nvtxRangePop();
quantize_w4a4_act(x, qact.act, qact.ascales);
kernels::quantize_w4a4_act(x, qact.act, qact.ascales);
#endif
return qact;
}
GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), dtype(dtype)
{
this->qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device);
this->wscales = Tensor::allocate({out_features}, dtype, device);
this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
registerParams
(qweight, "qweight")
(wscales, "wscales")
(this->bias, "bias")
;
}
GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
QuantizedActivation qact;
auto qshape = x.shape;
if (fuse_glu) {
qshape[-1] /= 2;
}
qact.act = Tensor::allocate(qshape, Tensor::INT8, x.device());
qact.ascales = Tensor::allocate({(int)x.numel() / x.shape[-1]}, this->dtype, x.device());
debug("quantize.x", x);
kernels::quantize_w8a8_act(x, qact.act, qact.ascales, fuse_glu);
debug("quantize.qact", qact.act);
debug("quantize.ascales", qact.ascales);
return qact;
}
Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
auto oshape = qact.act.shape;
oshape[-1] = out_features;
Tensor out = Tensor::allocate(oshape, this->dtype, qact.act.device());
kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias);
debug("gemm.out", out);
return out;
}
DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features)
{
this->weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device);
this->bias = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};
registerParams
(this->weight, "weight")
(this->bias, "bias")
;
}
Tensor DWCONV::forward(Tensor x) {
return dwconv_f16(x, this->weight, {}, this->bias);
}
\ No newline at end of file
......@@ -4,6 +4,21 @@
#include "Tensor.h"
#include "Module.h"
class GEMM_F16 : public Module {
public:
GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x);
public:
const int in_features;
const int out_features;
public:
Tensor weight;
Tensor bias;
};
class GEMV_AWQ : public Module {
public:
GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device);
......@@ -38,26 +53,34 @@ public:
enum class FuseOptions {
EMPTY = 0,
GELU_QUANT,
SILU,
};
struct QuantizedActivation {
Tensor act;
Tensor ascales;
Tensor lora_act;
bool is_unsigned = false;
TensorShape actShape;
};
public:
GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse = FuseOptions::EMPTY, GEMM_W4A4 *nextGEMM = nullptr);
Tensor forward(Tensor x);
Tensor forward_silu(Tensor x);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
void forward(Tensor x, Tensor out, Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {});
std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse = FuseOptions::EMPTY, GEMM_W4A4 *nextGEMM = nullptr);
std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
Tensor forward_quant(QuantizedActivation qact);
public:
QuantizedActivation quantize(Tensor x);
QuantizedActivation quantize(Tensor x, bool fuse_glu);
public:
const int in_features;
const int out_features;
const int in_features_pad;
const int out_features_pad;
int lora_rank;
std::vector<float> lora_scales; // every 16 ranks share a scale
......@@ -79,5 +102,41 @@ public:
cublasHandle_t handle;
};
// TODO
class GEMM_W8A8;
\ No newline at end of file
class GEMM_W8A8 : public Module {
public:
struct QuantizedActivation {
Tensor act;
Tensor ascales;
};
public:
GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
public:
QuantizedActivation quantize(Tensor x, bool fuse_glu);
Tensor forward_quant(QuantizedActivation qact);
Tensor forward(Tensor x) { return forward_quant(quantize(x, false)); }
public:
const int in_features;
const int out_features;
const Tensor::ScalarType dtype;
public:
Tensor qweight;
Tensor wscales;
Tensor bias;
};
class DWCONV : public Module {
public:
DWCONV(int in_features, bool bias, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x);
public:
const int in_features;
public:
Tensor weight;
Tensor bias;
};
\ No newline at end of file