Commit 54e6d065 authored by muyangli's avatar muyangli
Browse files

[major] support NVFP4; upgrade to 0.1

parent c7f41661
...@@ -28,7 +28,10 @@ namespace nunchaku::ops { ...@@ -28,7 +28,10 @@ namespace nunchaku::ops {
std::optional<torch::Tensor> out_linearattn,// linear [B, (M), N / 3] std::optional<torch::Tensor> out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned, bool act_unsigned,
std::vector<float> lora_scales, std::vector<float> lora_scales,
bool fuse_silu bool fuse_silu,
bool fp4,
float alpha,
std::optional<torch::Tensor> wcscales
) { ) {
spdlog::trace("running gemm_w4a4: "); spdlog::trace("running gemm_w4a4: ");
...@@ -63,7 +66,10 @@ namespace nunchaku::ops { ...@@ -63,7 +66,10 @@ namespace nunchaku::ops {
getTensor(out_linearattn), getTensor(out_linearattn),
act_unsigned, act_unsigned,
lora_scales, lora_scales,
fuse_silu fuse_silu,
fp4,
alpha,
getTensor(wcscales)
); );
Tensor::synchronizeDevice(); Tensor::synchronizeDevice();
} }
......
...@@ -11,6 +11,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -11,6 +11,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel") py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
.def(py::init<>()) .def(py::init<>())
.def("init", &QuantizedFluxModel::init, .def("init", &QuantizedFluxModel::init,
py::arg("use_fp4"),
py::arg("bf16"), py::arg("bf16"),
py::arg("deviceId") py::arg("deviceId")
) )
...@@ -33,6 +34,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -33,6 +34,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("init", &QuantizedSanaModel::init, .def("init", &QuantizedSanaModel::init,
py::arg("config"), py::arg("config"),
py::arg("pag_layers"), py::arg("pag_layers"),
py::arg("use_fp4"),
py::arg("bf16"), py::arg("bf16"),
py::arg("deviceId") py::arg("deviceId")
) )
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
class QuantizedSanaModel : public ModuleWrapper<SanaModel> { class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
public: public:
void init(pybind11::dict config, std::vector<int> pag_layers, bool bf16, int8_t deviceId) { void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedSanaModel"); spdlog::info("Initializing QuantizedSanaModel");
SanaConfig cfg{ SanaConfig cfg{
.num_layers = config["num_layers"].cast<int>(), .num_layers = config["num_layers"].cast<int>(),
...@@ -17,6 +17,7 @@ public: ...@@ -17,6 +17,7 @@ public:
.num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(), .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
.expand_ratio = config["mlp_ratio"].cast<double>(), .expand_ratio = config["mlp_ratio"].cast<double>(),
.pag_layers = pag_layers, .pag_layers = pag_layers,
.use_fp4 = use_fp4,
}; };
net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId)); net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
} }
......
...@@ -108,13 +108,12 @@ class EmbedND(nn.Module): ...@@ -108,13 +108,12 @@ class EmbedND(nn.Module):
return emb.unsqueeze(1) return emb.unsqueeze(1)
def load_quantized_module(path: str, device: str | torch.device = "cuda") -> QuantizedFluxModel: def load_quantized_module(path: str, device: str | torch.device = "cuda", use_fp4: bool = False) -> QuantizedFluxModel:
device = torch.device(device) device = torch.device(device)
assert device.type == "cuda" assert device.type == "cuda"
m = QuantizedFluxModel() m = QuantizedFluxModel()
cutils.disable_memory_auto_release() cutils.disable_memory_auto_release()
m.init(True, 0 if device.index is None else device.index) m.init(use_fp4, True, 0 if device.index is None else device.index)
m.load(path) m.load(path)
return m return m
...@@ -153,8 +152,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader ...@@ -153,8 +152,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
@utils.validate_hf_hub_args @utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
device = kwargs.get("device", "cuda") device = kwargs.get("device", "cuda")
precision = kwargs.get("precision", "int4")
assert precision in ["int4", "fp4"]
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs) transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
m = load_quantized_module(transformer_block_path, device=device) m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4")
transformer.inject_quantized_module(m, device) transformer.inject_quantized_module(m, device)
return transformer return transformer
......
...@@ -4,7 +4,13 @@ from diffusers import FluxPipeline ...@@ -4,7 +4,13 @@ from diffusers import FluxPipeline
from .models.transformer_flux import NunchakuFluxTransformer2dModel from .models.transformer_flux import NunchakuFluxTransformer2dModel
if __name__ == "__main__": if __name__ == "__main__":
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell") capability = torch.cuda.get_device_capability(0)
sm = f"{capability[0]}{capability[1]}"
precision = "fp4" if sm == "120" else "int4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", precision=precision
)
pipeline = FluxPipeline.from_pretrained( pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16 "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda") ).to("cuda")
......
import os import os
import re
import subprocess
import sys
import setuptools import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
class CustomBuildExtension(BuildExtension): class CustomBuildExtension(BuildExtension):
...@@ -19,6 +24,40 @@ class CustomBuildExtension(BuildExtension): ...@@ -19,6 +24,40 @@ class CustomBuildExtension(BuildExtension):
super().build_extensions() super().build_extensions()
def get_sm_targets() -> list[str]:
nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
try:
nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
if match:
nvcc_version = match.group(2)
else:
raise Exception("nvcc version not found")
print(f"Found nvcc version: {nvcc_version}")
except:
raise Exception("nvcc not found")
support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")
install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
if install_mode == "FAST":
ret = []
for i in range(torch.cuda.device_count()):
capability = torch.cuda.get_device_capability(i)
sm = f"{capability[0]}{capability[1]}"
if sm == "120" and support_sm120:
sm = "120a"
assert sm in ["80", "86", "89", "120a"], f"Unsupported SM {sm}"
if sm not in ret:
ret.append(sm)
else:
assert install_mode == "ALL"
ret = ["80", "86", "89"]
if support_sm120:
ret.append("120a")
return ret
if __name__ == "__main__": if __name__ == "__main__":
fp = open("nunchaku/__version__.py", "r").read() fp = open("nunchaku/__version__.py", "r").read()
version = eval(fp.strip().split()[-1]) version = eval(fp.strip().split()[-1])
...@@ -55,12 +94,6 @@ if __name__ == "__main__": ...@@ -55,12 +94,6 @@ if __name__ == "__main__":
NVCC_FLAGS = [ NVCC_FLAGS = [
"-DENABLE_BF16=1", "-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1", "-DBUILD_NUNCHAKU=1",
"-gencode",
"arch=compute_86,code=sm_86",
"-gencode",
"arch=compute_89,code=sm_89",
# "-gencode",
# "arch=compute_89,code=sm_120a",
"-g", "-g",
"-std=c++20", "-std=c++20",
"-UNDEBUG", "-UNDEBUG",
...@@ -75,13 +108,21 @@ if __name__ == "__main__": ...@@ -75,13 +108,21 @@ if __name__ == "__main__":
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__", "-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__", "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--threads=2", "--threads=3",
"--expt-relaxed-constexpr", "--expt-relaxed-constexpr",
"--expt-extended-lambda", "--expt-extended-lambda",
"--generate-line-info", "--generate-line-info",
"--ptxas-options=--allow-expensive-optimizations=true", "--ptxas-options=--allow-expensive-optimizations=true",
] ]
# https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
sm_targets = get_sm_targets()
print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
assert len(sm_targets) > 0, "No SM targets found"
for target in sm_targets:
NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"] NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]
nunchaku_extension = CUDAExtension( nunchaku_extension = CUDAExtension(
......
...@@ -259,19 +259,19 @@ void Attention::setForceFP16(Module *module, bool value) { ...@@ -259,19 +259,19 @@ void Attention::setForceFP16(Module *module, bool value) {
}); });
} }
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) : FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim), dim(dim),
dim_head(attention_head_dim / num_attention_heads), dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads), num_heads(num_attention_heads),
mlp_hidden_dim(dim * mlp_ratio), mlp_hidden_dim(dim * mlp_ratio),
norm(dim, dtype, device), norm(dim, dtype, device),
mlp_fc1(dim, mlp_hidden_dim, true, dtype, device), mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
mlp_fc2(mlp_hidden_dim, dim, true, dtype, device), mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
qkv_proj(dim, dim * 3, true, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
norm_q(dim_head, 1e-6, false, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
norm_k(dim_head, 1e-6, false, dtype, device), norm_k(dim_head, 1e-6, false, dtype, device),
attn(num_attention_heads, attention_head_dim / num_attention_heads, device), attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
out_proj(dim, dim, true, dtype, device) out_proj(dim, dim, true, use_fp4, dtype, device)
{ {
registerChildren registerChildren
(norm, "norm") (norm, "norm")
...@@ -327,28 +327,28 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te ...@@ -327,28 +327,28 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
return hidden_states; return hidden_states;
} }
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device) : JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim), dim(dim),
dim_head(attention_head_dim / num_attention_heads), dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads), num_heads(num_attention_heads),
context_pre_only(context_pre_only), context_pre_only(context_pre_only),
norm1(dim, false, dtype, device), norm1(dim, false, dtype, device),
norm1_context(dim, context_pre_only, dtype, device), norm1_context(dim, context_pre_only, dtype, device),
qkv_proj(dim, dim * 3, true, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
qkv_proj_context(dim, dim * 3, true, dtype, device), qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
norm_q(dim_head, 1e-6, false, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
norm_k(dim_head, 1e-6, false, dtype, device), norm_k(dim_head, 1e-6, false, dtype, device),
norm_added_q(dim_head, 1e-6, false, dtype, device), norm_added_q(dim_head, 1e-6, false, dtype, device),
norm_added_k(dim_head, 1e-6, false, dtype, device), norm_added_k(dim_head, 1e-6, false, dtype, device),
attn(num_attention_heads, attention_head_dim / num_attention_heads, device), attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
out_proj(dim, dim, true, dtype, device), out_proj(dim, dim, true, use_fp4, dtype, device),
out_proj_context(dim, dim, true, dtype, device), out_proj_context(dim, dim, true, use_fp4, dtype, device),
norm2(dim, 1e-6, false, dtype, device), norm2(dim, 1e-6, false, dtype, device),
norm2_context(dim, 1e-6, false, dtype, device), norm2_context(dim, 1e-6, false, dtype, device),
mlp_fc1(dim, dim * 4, true, dtype, device), mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),
mlp_fc2(dim * 4, dim, true, dtype, device), mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
mlp_context_fc1(dim, dim * 4, true, dtype, device), mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
mlp_context_fc2(dim * 4, dim, true, dtype, device) mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
{ {
registerChildren registerChildren
(norm1, "norm1") (norm1, "norm1")
...@@ -607,13 +607,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, ...@@ -607,13 +607,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
return { hidden_states, encoder_hidden_states }; return { hidden_states, encoder_hidden_states };
} }
FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) { FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
for (int i = 0; i < 19; i++) { for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, dtype, device)); transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i)); registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
} }
for (int i = 0; i < 38; i++) { for (int i = 0; i < 38; i++) {
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, dtype, Device::cuda())); single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i)); registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
} }
} }
......
...@@ -77,7 +77,7 @@ public: ...@@ -77,7 +77,7 @@ public:
static constexpr bool USE_4BIT = true; static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>; using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device); FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb); Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);
public: public:
...@@ -101,7 +101,7 @@ public: ...@@ -101,7 +101,7 @@ public:
static constexpr bool USE_4BIT = true; static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>; using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device); JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio); std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);
public: public:
...@@ -128,7 +128,7 @@ private: ...@@ -128,7 +128,7 @@ private:
class FluxModel : public Module { class FluxModel : public Module {
public: public:
FluxModel(Tensor::ScalarType dtype, Device device); FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single); Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
public: public:
......
...@@ -96,23 +96,33 @@ Tensor GEMV_AWQ::forward(Tensor x) { ...@@ -96,23 +96,33 @@ Tensor GEMV_AWQ::forward(Tensor x) {
#define NO_LORA_FUSION 0 #define NO_LORA_FUSION 0
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) : GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), in_features(in_features), out_features(out_features),
in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128), in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
use_fp4(use_fp4),
lora_rank(0), dtype(dtype) lora_rank(0), dtype(dtype)
{ {
this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true); this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true); if (use_fp4) {
this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
} else {
this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
}
this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{}; this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true); this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true); this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
// TODO: smooth factor in FC1+FC2 fusion
// TODO: smooth factor in non-Lora fusion // TODO: smooth factor in non-Lora fusion
this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true); this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);
// FIXME: reset wtscale and wcscales to default values when reloading the weights
this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
*this->wtscale.data_ptr<float>() = 1.0f;
this->wcscales = Tensor::allocate({0}, dtype, device, true);
registerParams registerParams
(qweight, "qweight") (qweight, "qweight")
(wscales, "wscales") (wscales, "wscales")
...@@ -120,6 +130,8 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::Scala ...@@ -120,6 +130,8 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::Scala
(lora_down, "lora_down", ParamFlags::Optional) (lora_down, "lora_down", ParamFlags::Optional)
(lora_up, "lora_up", ParamFlags::Optional) (lora_up, "lora_up", ParamFlags::Optional)
(smooth, "smooth") (smooth, "smooth")
(wtscale, "wtscale", ParamFlags::Optional)
(wcscales, "wcscales", ParamFlags::Optional)
; ;
#if NO_LORA_FUSION #if NO_LORA_FUSION
...@@ -137,6 +149,21 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) { ...@@ -137,6 +149,21 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else { } else {
dst.copy_(src); dst.copy_(src);
} }
} else if (key == "wcscales") {
assert(src.ndims() == 1);
assert(src.shape[0] == out_features_pad);
dst = src.copy(this->qweight.device());
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
*dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
} else if (src.dtype() == Tensor::FP16) {
*dst.data_ptr<float>() = float(*src.data_ptr<half>());
} else if (src.dtype() == Tensor::FP32) {
dst.copy_(src);
} else {
assert(false);
}
} else { } else {
Module::loadParam(key, dst, src); Module::loadParam(key, dst, src);
} }
...@@ -167,7 +194,10 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor ...@@ -167,7 +194,10 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
debug("gemm.nolora.out", out); debug("gemm.nolora.out", out);
#endif #endif
kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false); kernels::gemm_w4a4(
qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{}
);
debug("gemm.out", out); debug("gemm.out", out);
#else #else
...@@ -215,9 +245,13 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu ...@@ -215,9 +245,13 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
out = Tensor::allocate(shape, dtype, qweight.device()); out = Tensor::allocate(shape, dtype, qweight.device());
} else { } else {
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device()); qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device()); if (use_fp4) {
qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
} else {
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
}
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device()); qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qout.is_unsigned = true; qout.is_unsigned = !use_fp4;
qout.actShape = qact.actShape; qout.actShape = qact.actShape;
next_lora = nextGEMM->lora_down; next_lora = nextGEMM->lora_down;
...@@ -241,7 +275,10 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu ...@@ -241,7 +275,10 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
} }
#endif #endif
kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU); kernels::gemm_w4a4(
qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{}
);
if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) { if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
debug("gemm.out", out); debug("gemm.out", out);
...@@ -327,7 +364,11 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) { ...@@ -327,7 +364,11 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
QuantizedActivation qact; QuantizedActivation qact;
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device()); qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device()); if (use_fp4) {
qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
} else {
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
}
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device()); qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qact.is_unsigned = false; qact.is_unsigned = false;
qact.actShape = x.shape.dataExtent; qact.actShape = x.shape.dataExtent;
...@@ -336,7 +377,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) { ...@@ -336,7 +377,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
debug("quantize.x", x); debug("quantize.x", x);
debug("quantize.smooth", this->smooth); debug("quantize.smooth", this->smooth);
kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu); kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);
debug("quantize.qact", qact.act); debug("quantize.qact", qact.act);
debug("quantize.ascales", qact.ascales); debug("quantize.ascales", qact.ascales);
......
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
}; };
public: public:
GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device); GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x); Tensor forward(Tensor x);
Tensor forward_silu(Tensor x); Tensor forward_silu(Tensor x);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr); std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
...@@ -80,6 +80,7 @@ public: ...@@ -80,6 +80,7 @@ public:
const int out_features; const int out_features;
const int in_features_pad; const int in_features_pad;
const int out_features_pad; const int out_features_pad;
const bool use_fp4;
int lora_rank; int lora_rank;
std::vector<float> lora_scales; // every 16 ranks share a scale std::vector<float> lora_scales; // every 16 ranks share a scale
...@@ -99,6 +100,9 @@ public: ...@@ -99,6 +100,9 @@ public:
Tensor smooth; Tensor smooth;
Tensor wtscale;
Tensor wcscales;
cublasHandle_t handle; cublasHandle_t handle;
}; };
......
...@@ -8,11 +8,11 @@ ...@@ -8,11 +8,11 @@
using spdlog::fmt_lib::format; using spdlog::fmt_lib::format;
using namespace nunchaku; using namespace nunchaku;
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device) : SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim), dim(dim),
dim_pad(ceilDiv(dim, 128) * 128), dim_pad(ceilDiv(dim, 128) * 128),
qkv_proj(dim, dim_pad * 3, bias, dtype, device), qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
out_proj(dim_pad, dim, bias, dtype, device), out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
pag_to_v(std::nullopt) pag_to_v(std::nullopt)
{ {
registerChildren registerChildren
...@@ -21,7 +21,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::S ...@@ -21,7 +21,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::S
; ;
if (pag) { if (pag) {
pag_to_v.emplace(dim, dim_pad, bias, dtype, device); pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
registerChildren(pag_to_v.value(), "pag_to_v"); registerChildren(pag_to_v.value(), "pag_to_v");
} }
} }
...@@ -63,7 +63,11 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) { ...@@ -63,7 +63,11 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
qkv_proj.wscales, qkv_proj.wscales,
{}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {}, {}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
vk, q, vk, q,
qact.is_unsigned, qkv_proj.lora_scales, false); qact.is_unsigned, qkv_proj.lora_scales, false,
qkv_proj.use_fp4,
*qkv_proj.wtscale.data_ptr<float>(),
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{}
);
debug("vk", vk); debug("vk", vk);
debug("q", q); debug("q", q);
...@@ -121,11 +125,11 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) { ...@@ -121,11 +125,11 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
return out; return out;
} }
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device) : MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
num_heads(num_heads), head_dim(head_dim), num_heads(num_heads), head_dim(head_dim),
q_linear(num_heads * head_dim, num_heads * head_dim, true, dtype, device), q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device), kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
out_proj(num_heads * head_dim, num_heads * head_dim, true, dtype, device) out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device)
{ {
registerChildren registerChildren
(q_linear, "q_linear") (q_linear, "q_linear")
...@@ -173,11 +177,11 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens ...@@ -173,11 +177,11 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
return out_proj.forward(attn_output); return out_proj.forward(attn_output);
} }
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device) : SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
in_features(in_features), hidden_features(hidden_features), in_features(in_features), hidden_features(hidden_features),
inverted_conv(in_features, hidden_features * 2, true, dtype, device), inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
depth_conv(hidden_features * 2, true, dtype, device), depth_conv(hidden_features * 2, true, dtype, device),
point_conv(hidden_features, in_features, false, dtype, device) point_conv(hidden_features, in_features, false, use_fp4, dtype, device)
{ {
registerChildren registerChildren
(inverted_conv, "inverted_conv") (inverted_conv, "inverted_conv")
...@@ -200,11 +204,11 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) { ...@@ -200,11 +204,11 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
return point_conv.forward_quant(qact); return point_conv.forward_quant(qact);
} }
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device) : SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads), hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
attn(hidden_size, false, pag, dtype, device), attn(hidden_size, false, pag, use_fp4, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, dtype, device), cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
ff(hidden_size, intermediate_size, dtype, device), ff(hidden_size, intermediate_size, use_fp4, dtype, device),
norm1(hidden_size, 1e-6, false, dtype, device), norm1(hidden_size, 1e-6, false, dtype, device),
norm2(hidden_size, 1e-6, false, dtype, device) norm2(hidden_size, 1e-6, false, dtype, device)
{ {
...@@ -313,6 +317,7 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) ...@@ -313,6 +317,7 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64, ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64,
config.num_cross_attention_heads, config.num_cross_attention_heads,
std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(), std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
config.use_fp4,
dtype, device dtype, device
)); ));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i)); registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
class SanaLinearAttention : public Module { class SanaLinearAttention : public Module {
public: public:
SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device); SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, Tensor out = {}); Tensor forward(Tensor x, Tensor out = {});
Tensor forward_pag(Tensor x, bool cfg); Tensor forward_pag(Tensor x, bool cfg);
...@@ -25,7 +25,7 @@ private: ...@@ -25,7 +25,7 @@ private:
class MultiHeadCrossAttention : public Module { class MultiHeadCrossAttention : public Module {
public: public:
MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device); MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt); Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);
...@@ -41,7 +41,7 @@ private: ...@@ -41,7 +41,7 @@ private:
class SanaGLUMBConv : public Module { class SanaGLUMBConv : public Module {
public: public:
SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device); SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, int H, int W); Tensor forward(Tensor x, int H, int W);
...@@ -57,7 +57,7 @@ private: ...@@ -57,7 +57,7 @@ private:
class SanaLinearTransformerBlock : public Module { class SanaLinearTransformerBlock : public Module {
public: public:
SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device); SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg); Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
...@@ -83,6 +83,7 @@ struct SanaConfig { ...@@ -83,6 +83,7 @@ struct SanaConfig {
int num_cross_attention_heads; int num_cross_attention_heads;
double expand_ratio; double expand_ratio;
std::vector<int> pag_layers; std::vector<int> pag_layers;
bool use_fp4;
}; };
class SanaModel : public Module { class SanaModel : public Module {
......
...@@ -117,6 +117,8 @@ void SafeTensors::parseHeader() { ...@@ -117,6 +117,8 @@ void SafeTensors::parseHeader() {
{ "I8", Tensor::INT8 }, { "I8", Tensor::INT8 },
{ "I32", Tensor::INT32 }, { "I32", Tensor::INT32 },
{ "I64", Tensor::INT64 }, { "I64", Tensor::INT64 },
{ "F8_E4M3", Tensor::FP8_E4M3 },
{ "F8_E5M2", Tensor::FP8_E5M2 },
}; };
auto check = [](bool cond, std::source_location location = std::source_location::current()) { auto check = [](bool cond, std::source_location location = std::source_location::current()) {
......
...@@ -218,7 +218,8 @@ public: ...@@ -218,7 +218,8 @@ public:
enum ScalarType { enum ScalarType {
INVALID_SCALAR_TYPE, INVALID_SCALAR_TYPE,
INT8, INT32, INT64, INT8, INT32, INT64,
FP16, FP32, BF16 FP16, FP32, BF16,
FP8_E4M3, FP8_E5M2,
}; };
struct TensorOptions { struct TensorOptions {
...@@ -545,6 +546,8 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = { ...@@ -545,6 +546,8 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
{FP16, 2}, {FP16, 2},
{FP32, 4}, {FP32, 4},
{BF16, 2}, {BF16, 2},
{FP8_E4M3, 1},
{FP8_E5M2, 1},
}; };
struct TensorsProvider { struct TensorsProvider {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <memory> #include <memory>
#include <source_location> #include <source_location>
#include <vector> #include <vector>
#include <list>
#include <stack> #include <stack>
#include <map> #include <map>
#include <unordered_map> #include <unordered_map>
...@@ -79,6 +80,15 @@ constexpr T ceilDiv(T a, T b) { ...@@ -79,6 +80,15 @@ constexpr T ceilDiv(T a, T b) {
return (a + b - 1) / b; return (a + b - 1) / b;
} }
template<typename T>
constexpr int log2Up(T value) {
if (value <= 0)
return 0;
if (value == 1)
return 0;
return log2Up((value + 1) / 2) + 1;
}
struct CUBLASWrapper { struct CUBLASWrapper {
cublasHandle_t handle = nullptr; cublasHandle_t handle = nullptr;
......
...@@ -28,6 +28,8 @@ Tensor from_torch(at::Tensor input) { ...@@ -28,6 +28,8 @@ Tensor from_torch(at::Tensor input) {
{ at::ScalarType::Float, Tensor::FP32 }, { at::ScalarType::Float, Tensor::FP32 },
{ at::ScalarType::Half, Tensor::FP16 }, { at::ScalarType::Half, Tensor::FP16 },
{ at::ScalarType::BFloat16, Tensor::BF16 }, { at::ScalarType::BFloat16, Tensor::BF16 },
{ at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3 },
{ at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2 },
}; };
result.scalarType = mapType.at(input.scalar_type()); result.scalarType = mapType.at(input.scalar_type());
...@@ -53,6 +55,8 @@ at::Tensor to_torch(Tensor input) { ...@@ -53,6 +55,8 @@ at::Tensor to_torch(Tensor input) {
{ Tensor::FP32, at::ScalarType::Float }, { Tensor::FP32, at::ScalarType::Float },
{ Tensor::FP16, at::ScalarType::Half }, { Tensor::FP16, at::ScalarType::Half },
{ Tensor::BF16, at::ScalarType::BFloat16 }, { Tensor::BF16, at::ScalarType::BFloat16 },
{ Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn },
{ Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2 },
}; };
c10::TensorOptions opts(mapType.at(input.scalar_type())); c10::TensorOptions opts(mapType.at(input.scalar_type()));
......
...@@ -140,8 +140,10 @@ __global__ void gemv_kernel( ...@@ -140,8 +140,10 @@ __global__ void gemv_kernel(
for (int i = 0; i < Num; ++i) for (int i = 0; i < Num; ++i)
psum[i] = static_cast<accum_t>(0.f); psum[i] = static_cast<accum_t>(0.f);
extern __shared__ uint8_t shmem[]; // extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem); // float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
__shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave];
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave; const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave; const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
......
...@@ -319,10 +319,10 @@ public: ...@@ -319,10 +319,10 @@ public:
int warpId = threadIdx.x / WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
if (pred) { //if (pred) {
// out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]); // out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
out[i] = load(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId]); out[i] = load_pred(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], pred);
} //}
} }
} }
...@@ -336,12 +336,12 @@ public: ...@@ -336,12 +336,12 @@ public:
// int offset = K / WARP_K * WARP_SIZE; // int offset = K / WARP_K * WARP_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_N_TILES; i++) { for (int i = 0; i < WARP_N_TILES; i++) {
if (pred) { //if (pred) {
// out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]); // out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
// out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]); // out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
out[i] = load(&ptr[i * WARP_SIZE]); out[i] = load_pred(&ptr[i * WARP_SIZE], pred);
// ptr += offset; // ptr += offset;
} //}
} }
} }
...@@ -352,11 +352,11 @@ public: ...@@ -352,11 +352,11 @@ public:
int warpId = threadIdx.x / WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < ASCALES_NUM_PACKS; i++) { for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
if (pred && laneId < ASCALES_VALID_LANES) { // if (pred && laneId < ASCALES_VALID_LANES) {
// out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId]; // out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
out[i] = ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId]; out[i] = load_pred(&ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId], pred && laneId < ASCALES_VALID_LANES);
} // }
} }
} }
...@@ -373,13 +373,13 @@ public: ...@@ -373,13 +373,13 @@ public:
#pragma unroll #pragma unroll
for (int i = 0; i < WSCALES_NUM_PACKS; i++) { for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
if (pred && laneId < WSCALES_VALID_LANES) { // if (pred && laneId < WSCALES_VALID_LANES) {
// out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]; // out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
// out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]); // out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
out[i] = load(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId]); out[i] = load_pred(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId], pred && laneId < WSCALES_VALID_LANES);
// out[i] = load(&ptr[i * WSCALES_VALID_LANES]); // out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
} // }
} }
} }
...@@ -400,7 +400,7 @@ public: ...@@ -400,7 +400,7 @@ public:
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane); return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
} }
template<typename F> template<bool FAST_I2F = false, typename F>
__device__ __forceinline__ __device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) { static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
...@@ -429,12 +429,31 @@ public: ...@@ -429,12 +429,31 @@ public:
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y); // printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// } // }
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
auto scale_fma_normal = [&]() ALWAYSINLINE {
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
};
// should be faster on sm_80
auto scale_fma_fast = [&]() ALWAYSINLINE {
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[0]), int2float_fast(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[2]), int2float_fast(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[4]), int2float_fast(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[6]), int2float_fast(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
if constexpr (FAST_I2F) {
scale_fma_fast();
} else {
scale_fma_normal();
}
#else
scale_fma_normal();
#endif
// if (threadIdx.x == 3 && j == 1 && i == 0) { // if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y); // printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// } // }
...@@ -575,9 +594,9 @@ public: ...@@ -575,9 +594,9 @@ public:
(plugins(i * INSN_M + row, pack), ...); (plugins(i * INSN_M + row, pack), ...);
bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols; bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) { // if (pred) {
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack); store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack, pred);
} // }
} }
__syncwarp(); __syncwarp();
...@@ -602,9 +621,9 @@ public: ...@@ -602,9 +621,9 @@ public:
(plugins(i * INSN_M + 8 + row, pack), ...); (plugins(i * INSN_M + 8 + row, pack), ...);
bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols; bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) { // if (pred) {
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack); store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack, pred);
} // }
} }
__syncwarp(); __syncwarp();
} }
...@@ -680,33 +699,61 @@ public: ...@@ -680,33 +699,61 @@ public:
} }
}; };
template<bool USE_BIAS = true, bool USE_SCALE = false>
struct EpilogueBias { struct EpilogueBias {
struct Arguments { struct Arguments {
const packed_wscale_t *bias; // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t const packed_wscale_t *bias; // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
const packed_wscale_t *scale;
}; };
__device__ __forceinline__ __device__ __forceinline__
void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias) { void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias, const packed_wscale_t *scale) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
// if (laneId == 0) { // if (laneId == 0) {
// printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias); // printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
// } // }
wscale_warp b; wscale_warp b, s;
load_wscale(bias, 0, N, b, true); if constexpr (USE_BIAS) {
load_wscale(bias, 0, N, b, true);
}
if constexpr (USE_SCALE) {
load_wscale(scale, 0, N, s, true);
}
for (int j = 0; j < WARP_N_TILES; j++) { for (int j = 0; j < WARP_N_TILES; j++) {
half2_t b1 = broadcast_wscale(b, j * 4, laneId); half2_t b1, b2;
half2_t b2 = broadcast_wscale(b, j * 4 + 2, laneId); half2_t s1, s2;
if constexpr (USE_BIAS) {
b1 = broadcast_wscale(b, j * 4, laneId);
b2 = broadcast_wscale(b, j * 4 + 2, laneId);
}
if constexpr (USE_SCALE) {
s1 = broadcast_wscale(s, j * 4, laneId);
s2 = broadcast_wscale(s, j * 4 + 2, laneId);
}
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j]; auto &fsum = fpsum[i * WARP_N_TILES + j];
fsum.data[0] = __hadd2(fsum.data[0], b1); if constexpr (USE_SCALE && USE_BIAS) {
fsum.data[1] = __hadd2(fsum.data[1], b1); fsum.data[0] = __hfma2(fsum.data[0], s1, b1);
fsum.data[2] = __hadd2(fsum.data[2], b2); fsum.data[1] = __hfma2(fsum.data[1], s1, b1);
fsum.data[3] = __hadd2(fsum.data[3], b2); fsum.data[2] = __hfma2(fsum.data[2], s2, b2);
fsum.data[3] = __hfma2(fsum.data[3], s2, b2);
} else if constexpr (USE_SCALE) {
fsum.data[0] = __hmul2(fsum.data[0], s1);
fsum.data[1] = __hmul2(fsum.data[1], s1);
fsum.data[2] = __hmul2(fsum.data[2], s2);
fsum.data[3] = __hmul2(fsum.data[3], s2);
} else if constexpr (USE_BIAS) {
fsum.data[0] = __hadd2(fsum.data[0], b1);
fsum.data[1] = __hadd2(fsum.data[1], b1);
fsum.data[2] = __hadd2(fsum.data[2], b2);
fsum.data[3] = __hadd2(fsum.data[3], b2);
}
} }
} }
} }
...@@ -714,10 +761,13 @@ public: ...@@ -714,10 +761,13 @@ public:
__device__ __forceinline__ __device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) { void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bn = binfo.bn; const int bn = binfo.bn;
apply_bias( if constexpr (USE_BIAS || USE_SCALE) {
fpsum, M, N, K, apply_bias(
args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES fpsum, M, N, K,
); args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
args.scale + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
}
} }
}; };
...@@ -797,7 +847,8 @@ public: ...@@ -797,7 +847,8 @@ public:
using typename Base::unpack_fpsum; \ using typename Base::unpack_fpsum; \
using typename Base::EpilogueDefault; \ using typename Base::EpilogueDefault; \
using typename Base::EpilogueNop; \ using typename Base::EpilogueNop; \
using typename Base::EpilogueBias; template<bool USE_BIAS, bool USE_SCALE> \
using EpilogueBias = typename Base::EpilogueBias<USE_BIAS, USE_SCALE>;
template<typename kernel, typename ...T> template<typename kernel, typename ...T>
......
...@@ -43,6 +43,41 @@ static T load(const T *addr) { ...@@ -43,6 +43,41 @@ static T load(const T *addr) {
return *addr; return *addr;
} }
template<typename T>
__device__ __forceinline__
static T load_pred(const T *addr, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
"@loadpred ld.global.nc.b32 %0, [%1];"
"}" : "=r"(data) : "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 8) {
uint2 data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
"@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
"}" : "=r"(data.x), "=r"(data.y) : "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 16) {
uint4 data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
"@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
"}" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
T result;
if (pred) {
result = *addr;
}
return result;
}
template<bool shmem = false, typename T> template<bool shmem = false, typename T>
__device__ __forceinline__ __device__ __forceinline__
static void store(T *addr, T val) { static void store(T *addr, T val) {
...@@ -76,6 +111,39 @@ static void store(T *addr, T val) { ...@@ -76,6 +111,39 @@ static void store(T *addr, T val) {
*addr = val; *addr = val;
} }
template<typename T>
__device__ __forceinline__
static void store_pred(T *addr, T val, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data = *reinterpret_cast<uint32_t *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.b32 [%1], %2;"
"}" :: "r"((int)pred), "l"(addr), "r"(data));
return;
}
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
"}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y));
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
"}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
return;
}
if (pred) {
*addr = val;
}
}
__device__ __forceinline__ __device__ __forceinline__
static float2 half22float2(half2 val) { static float2 half22float2(half2 val) {
return __half22float2(val); return __half22float2(val);
...@@ -159,6 +227,21 @@ uint32_t quantize_float2<8, false>(float2 value) { ...@@ -159,6 +227,21 @@ uint32_t quantize_float2<8, false>(float2 value) {
return result; return result;
} }
__device__ __forceinline__
uint32_t quantize_float2_fp4(float2 value) {
uint32_t result;
asm volatile ("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }" : "=r"(result) : "f"(value.y), "f"(value.x));
return result;
}
__device__ __forceinline__
uint32_t quantize_float4_fp8(float4 value) {
uint16_t lo, hi;
asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
return uint32_t(lo) | (uint32_t(hi) << 16);
}
__device__ __forceinline__ __device__ __forceinline__
static float cuda_tanhf(float x) { static float cuda_tanhf(float x) {
float result; float result;
...@@ -271,4 +354,14 @@ static void unrolled_loop(F &&lambda) { ...@@ -271,4 +354,14 @@ static void unrolled_loop(F &&lambda) {
call(std::make_integer_sequence<int, cnt>()); call(std::make_integer_sequence<int, cnt>());
} }
// int2float is slow on sm_80 and before
// val in [-4194304, 4194303]
__device__ __forceinline__
static float int2float_fast(int val) {
float fval;
// fval = (val & 0x7FFFFF) ^ 0x4B400000
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=f"(fval) : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
return fval - 12582912.0f;
}
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -36,9 +36,23 @@ void gemm_w4a4( ...@@ -36,9 +36,23 @@ void gemm_w4a4(
Tensor out_linearattn,// linear [B, (M), N / 3] Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned, bool act_unsigned,
std::vector<float> lora_scales, // [R / 16] std::vector<float> lora_scales, // [R / 16]
bool fuse_silu bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales
) { ) {
invoke_launch(ascales.dtype(), [&]<typename Config>() { Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
if (!fp4) {
dtype = ascales.dtype();
} else {
for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) {
if (tensor.valid()) {
assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype());
dtype = tensor.dtype();
}
}
}
invoke_launch(dtype, [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::gemm_w4a4( GEMM_W4A4_Launch<Config>::gemm_w4a4(
act, act,
wgt, wgt,
...@@ -61,7 +75,10 @@ void gemm_w4a4( ...@@ -61,7 +75,10 @@ void gemm_w4a4(
out_linearattn, out_linearattn,
act_unsigned, act_unsigned,
lora_scales, lora_scales,
fuse_silu fuse_silu,
fp4,
alpha,
wcscales
); );
}); });
} }
...@@ -72,10 +89,10 @@ void linearattn_vk_mul_q(Tensor q, Tensor vk) { ...@@ -72,10 +89,10 @@ void linearattn_vk_mul_q(Tensor q, Tensor vk) {
}); });
} }
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) { void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
invoke_launch(input.dtype(), [&]<typename Config>() { invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora( GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
); );
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment