Commit 92ac7b40 authored by sxtyzhangzk, committed by Zhekai Zhang

Add our own FP16 Attention implementation

parent 182c323c
......@@ -143,8 +143,20 @@ public:
});
}
void forceFP16Attention(bool enable) {
Attention::setForceFP16(net.get(), enable);
void setAttentionImpl(std::string name) {
if (name.empty() || name == "default") {
name = "flashattn2";
}
spdlog::info("Set attention implementation to {}", name);
if (name == "flashattn2") {
net->setAttentionImpl(AttentionImpl::FlashAttention2);
} else if (name == "nunchaku-fp16") {
net->setAttentionImpl(AttentionImpl::NunchakuFP16);
} else {
throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
}
}
};
\ No newline at end of file
......@@ -32,7 +32,11 @@ namespace nunchaku::ops {
bool fuse_silu,
bool fp4,
float alpha,
std::optional<torch::Tensor> wcscales
std::optional<torch::Tensor> wcscales,
std::optional<torch::Tensor> out_q, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
int attn_tokens
) {
spdlog::trace("running gemm_w4a4: ");
......@@ -70,11 +74,31 @@ namespace nunchaku::ops {
fuse_silu,
fp4,
alpha,
getTensor(wcscales)
getTensor(wcscales),
getTensor(out_q),
getTensor(out_k),
getTensor(out_v),
attn_tokens
);
// Tensor::synchronizeDevice();
}
void attention_fp16(
torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale
) {
nunchaku::kernels::attention_fp16(
from_torch(q),
from_torch(k),
from_torch(v),
from_torch(o),
scale
);
}
torch::Tensor gemv_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
......@@ -122,6 +146,36 @@ namespace nunchaku::ops {
return output;
}
void test_rmsnorm_rope(
torch::Tensor input,
torch::Tensor output,
torch::Tensor norm_q,
torch::Tensor norm_k,
torch::Tensor rotary_emb)
{
nunchaku::kernels::test_rmsnorm_rope(
from_torch(input),
from_torch(output),
from_torch(norm_q),
from_torch(norm_k),
from_torch(rotary_emb)
);
}
void test_pack_qkv(
torch::Tensor input,
torch::Tensor out_q,
torch::Tensor out_k,
torch::Tensor out_v,
int numTokens)
{
nunchaku::kernels::test_pack_qkv(
from_torch(input),
from_torch(out_q),
from_torch(out_k),
from_torch(out_v),
numTokens
);
}
};
\ No newline at end of file
......@@ -33,7 +33,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("stopDebug", &QuantizedFluxModel::stopDebug)
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale)
.def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
.def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
;
py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
.def(py::init<>())
......@@ -82,14 +82,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
;
m.def_submodule("ops")
.def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
.def("attention_fp16", nunchaku::ops::attention_fp16)
.def("gemm_awq", nunchaku::ops::gemm_awq)
.def("gemv_awq", nunchaku::ops::gemv_awq)
.def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
.def("test_pack_qkv", nunchaku::ops::test_pack_qkv)
;
m.def_submodule("utils")
.def("set_log_level", [](const std::string &level) {
spdlog::set_level(spdlog::level::from_str(level));
})
.def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
.def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
.def("trim_memory", nunchaku::utils::trim_memory)
;
......
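A minimal sketch (not part of this commit) of driving the new attention_fp16 binding directly from Python; the extension module path nunchaku._C, the head count, and the all-zero inputs are assumptions, and real q/k/v must be in the packed [B, H, M, D] layout produced by the pack-QKV epilogue, with the token count padded to a multiple of 256:

import torch
from nunchaku import _C  # assumed extension module name

B, H, M, D = 1, 24, 512, 128  # M padded to a multiple of 256, D = HEAD_DIM
q = torch.zeros(B, H, M, D, dtype=torch.float16, device="cuda")
k = torch.zeros(B, H, M, D, dtype=torch.float16, device="cuda")
v = torch.zeros(B, H, M, D, dtype=torch.float16, device="cuda")
o = torch.zeros(B, M, H * D, dtype=torch.bfloat16, device="cuda")  # linear output layout
_C.ops.attention_fp16(q, k, v, o, D ** -0.5)  # softmax scale = 1/sqrt(HEAD_DIM)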
......@@ -5,6 +5,13 @@
namespace nunchaku::utils {
void set_cuda_stack_limit(int64_t newval) {
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
}
void disable_memory_auto_release() {
int device;
checkCUDA(cudaGetDevice(&device));
......
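Hedged usage sketch of the new set_cuda_stack_limit binding (the nunchaku._C module path is an assumption; the limit is in bytes, matching cudaDeviceSetLimit(cudaLimitStackSize, ...)):

from nunchaku import _C  # assumed extension module name
_C.utils.set_cuda_stack_limit(8192)  # raise the per-thread CUDA stack to 8 KiB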
......@@ -22,6 +22,28 @@ class NunchakuFluxTransformerBlocks(nn.Module):
self.dtype = torch.bfloat16
self.device = device
@staticmethod
def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
assert rotemb.dtype == torch.float32
B = rotemb.shape[0]
M = rotemb.shape[1]
D = rotemb.shape[2] * 2
assert rotemb.shape == (B, M, D // 2, 1, 2)
assert M % 16 == 0
assert D % 8 == 0
rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
rotemb = rotemb.permute(0, 1, 3, 2, 4)
# 16*8 pack, FP32 accumulator (C) format
# https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
##########################################|--M--|--D--|
##########################################|-3--4--5--6|
########################################## : : : :
rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
rotemb = rotemb.contiguous()
rotemb = rotemb.view(B, M, D)
return rotemb
def forward(
self,
hidden_states: torch.Tensor,
......@@ -53,9 +75,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
rotary_emb_single = image_rotary_emb # .to(self.dtype)
rotary_emb_txt = pad_tensor(rotary_emb_txt, 256, 1)
rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)
rotary_emb_single = pad_tensor(rotary_emb_single, 256, 1)
rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
hidden_states = self.m.forward(
hidden_states,
......@@ -104,8 +126,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
rotary_emb_txt = pad_tensor(rotary_emb_txt, 256, 1)
rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)
rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
hidden_states, encoder_hidden_states = self.m.forward_layer(
idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt
......@@ -254,6 +276,11 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
if len(self.unquantized_loras) > 0:
self.update_unquantized_lora_params(strength)
def set_attention_impl(self, impl: str):
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.setAttentionImpl(impl)
def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
print("Injecting quantized module")
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
......
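A hedged usage sketch of the new model-level switch; the checkpoint path is a placeholder and from_pretrained is assumed to come from the existing loader mixin:

transformer = NunchakuFluxTransformer2dModel.from_pretrained("<quantized-flux-checkpoint>")
transformer.set_attention_impl("nunchaku-fp16")  # the FP16 attention kernel added in this commit
# transformer.set_attention_impl("flashattn2")   # or keep the default FlashAttention-2 path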
......@@ -158,9 +158,13 @@ if __name__ == "__main__":
"src/kernels/layernorm_kernels.cu",
"src/kernels/misc_kernels.cu",
"src/kernels/zgemm/gemm_w4a4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16.cu",
"src/kernels/zgemm/gemm_w4a4_test.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
"src/kernels/zgemm/gemm_w8a8.cu",
"src/kernels/zgemm/attention.cu",
"src/kernels/dwconv.cu",
"src/kernels/gemm_batched.cu",
"src/kernels/gemm_f16.cu",
......
#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"
......@@ -235,7 +236,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
Tensor raw_attn_output = mha_varlen_fwd(
q, k, v,
cu_seqlens, cu_seqlens,
num_tokens_img + num_tokens_context, num_tokens_img + num_tokens_context,
num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
0.0f,
pow(q.shape[-1], (-0.5)),
false, false, -1, -1, false
......@@ -298,19 +299,49 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
Tensor residual = hidden_states;
Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
// qkv_proj.forward(norm_hidden_states, qkv, {});
// debug("qkv_raw", qkv);
Tensor attn_output;
debug("rotary_emb", rotary_emb);
qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
Tensor attn_output = attn.forward(qkv, {}, 0);
attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
if (attnImpl == AttentionImpl::FlashAttention2) {
Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
// qkv_proj.forward(norm_hidden_states, qkv, {});
// debug("qkv_raw", qkv);
qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
attn_output = attn.forward(qkv, {}, 0);
attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
assert(batch_size == 1);
const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
Tensor q = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
qkv_proj.forward(norm_hidden_states, {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q, k, v, num_tokens);
debug("packed_q", q);
debug("packed_k", k);
debug("packed_v", v);
Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));
attn_output = o.slice(1, 0, num_tokens);
} else {
assert(false);
}
debug("raw_attn_output", attn_output);
attn_output = forward_fc(out_proj, attn_output);
debug("attn_output", attn_output);
......@@ -384,13 +415,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
int num_tokens_img = hidden_states.shape[1];
int num_tokens_context = encoder_hidden_states.shape[1];
int num_tokens_txt = encoder_hidden_states.shape[1];
assert(hidden_states.shape[2] == dim);
assert(encoder_hidden_states.shape[2] == dim);
spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str());
spdlog::debug("batch_size={} num_tokens_img={} num_tokens_context={}", batch_size, num_tokens_img, num_tokens_context);
spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);
auto norm1_output = norm1.forward(hidden_states, temb);
auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);
......@@ -408,76 +439,137 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePop();
auto stream = getCurrentCUDAStream();
Tensor concat;
Tensor pool;
{
nvtxRangePushA("qkv_proj");
const bool blockSparse = sparsityRatio > 0;
const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_context / POOL_SIZE;
concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_context, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
Tensor raw_attn_output;
pool = blockSparse
? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
: Tensor{};
if (attnImpl == AttentionImpl::FlashAttention2) {
num_tokens_img_pad = num_tokens_img;
num_tokens_txt_pad = num_tokens_txt;
for (int i = 0; i < batch_size; i++) {
// img first
Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_context);
Tensor pool_qkv = pool.valid()
? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
: Tensor{};
Tensor pool_qkv_context = pool.valid()
? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_context / POOL_SIZE)
Tensor concat;
Tensor pool;
{
nvtxRangePushA("qkv_proj");
const bool blockSparse = sparsityRatio > 0;
const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
pool = blockSparse
? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
: Tensor{};
for (int i = 0; i < batch_size; i++) {
// img first
Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
Tensor pool_qkv = pool.valid()
? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
: Tensor{};
Tensor pool_qkv_context = pool.valid()
? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
: Tensor{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// debug("qkv_raw", qkv);
debug("rotary_emb", rotary_emb);
qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
// debug("qkv_context_raw", qkv_context);
debug("rotary_emb_context", rotary_emb_context);
qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
debug("qkv_context", qkv_context);
}
nvtxRangePop();
}
spdlog::debug("concat={}", concat.shape.str());
debug("concat", concat);
assert(concat.shape[2] == num_heads * dim_head * 3);
nvtxRangePushA("Attention");
raw_attn_output = attn.forward(concat, pool, sparsityRatio);
nvtxRangePop();
spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// debug("qkv_raw", qkv);
debug("rotary_emb", rotary_emb);
qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
// debug("qkv_context_raw", qkv_context);
Tensor concat_q, concat_k, concat_v;
debug("rotary_emb_context", rotary_emb_context);
{
nvtxRangePushA("qkv_proj");
concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
concat_k = Tensor::empty_like(concat_q);
concat_v = Tensor::empty_like(concat_q);
for (int i = 0; i < batch_size; i++) {
// img first
auto sliceImg = [&](Tensor x) {
return x.slice(0, i, i+1).slice(2, 0, num_tokens_img_pad);
};
auto sliceTxt = [&](Tensor x) {
return x.slice(0, i, i+1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
};
qkv_proj.forward(
norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img
);
qkv_proj_context.forward(
norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context,
sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt
);
}
qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
debug("qkv_context", qkv_context);
debug("concat_q", concat_q);
debug("concat_k", concat_k);
debug("concat_v", concat_v);
nvtxRangePop();
}
nvtxRangePop();
}
spdlog::debug("concat={}", concat.shape.str());
debug("concat", concat);
raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head}, norm1_output.x.scalar_type(), norm1_output.x.device());
assert(concat.shape[2] == num_heads * dim_head * 3);
nvtxRangePushA("Attention");
nvtxRangePushA("Attention");
kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
Tensor raw_attn_output = attn.forward(concat, pool, sparsityRatio);
nvtxRangePop();
nvtxRangePop();
spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
} else {
assert(false);
}
raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_context, num_heads, dim_head});
debug("raw_attn_output", raw_attn_output);
{
nvtxRangePushA("o_proj");
auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;
// raw_attn_output: [batch_size, num_tokens_img + num_tokens_context, num_heads * dim_head]
// raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
Tensor raw_attn_output_split;
if (batch_size == 1) {
......@@ -488,7 +580,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
raw_attn_output_split.data_ptr(),
num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr(),
(num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
(num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
......@@ -546,15 +638,15 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor raw_attn_output_split;
if (batch_size == 1) {
raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img, num_tokens_img + num_tokens_context).reshape({batch_size, num_tokens_context, num_heads * dim_head});
raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt).reshape({batch_size, num_tokens_txt, num_heads * dim_head});
} else {
raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_context, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
checkCUDA(cudaMemcpy2DAsync(
raw_attn_output_split.data_ptr(),
num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr<char>() + num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
(num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head * raw_attn_output_split.scalar_size(),
(num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
stream));
......@@ -682,4 +774,13 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
helper.run();
return hidden_states;
}
\ No newline at end of file
}
void FluxModel::setAttentionImpl(AttentionImpl impl) {
for (auto &&block : this->transformer_blocks) {
block->attnImpl = impl;
}
for (auto &&block : this->single_transformer_blocks) {
block->attnImpl = impl;
}
}
......@@ -6,6 +6,11 @@
#include "Linear.h"
#include "layernorm.h"
enum class AttentionImpl {
FlashAttention2 = 0,
NunchakuFP16,
};
class AdaLayerNormZeroSingle : public Module {
public:
static constexpr bool USE_4BIT = true;
......@@ -86,6 +91,8 @@ public:
const int num_heads;
const int mlp_hidden_dim;
AttentionImpl attnImpl = AttentionImpl::FlashAttention2;
private:
AdaLayerNormZeroSingle norm;
GEMM mlp_fc1;
......@@ -110,6 +117,8 @@ public:
const int num_heads;
const bool context_pre_only;
AttentionImpl attnImpl = AttentionImpl::FlashAttention2;
private:
AdaLayerNormZero norm1;
AdaLayerNormZero norm1_context;
......@@ -131,6 +140,8 @@ public:
FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, bool skip_first_layer = false);
void setAttentionImpl(AttentionImpl impl);
public:
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
......
......@@ -181,7 +181,7 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x
return forward_quant(quantize(x, false), fuse, nextGEMM);
}
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
QuantizedActivation qact = quantize(x, false);
#if !NO_LORA_FUSION
......@@ -196,7 +196,8 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
kernels::gemm_w4a4(
qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{}
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{},
out_q, out_k, out_v, numTokens
);
debug("gemm.out", out);
......@@ -277,7 +278,8 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
kernels::gemm_w4a4(
qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{}
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{},
{}, {}, {}, 0
);
if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
......
......@@ -69,7 +69,11 @@ public:
Tensor forward(Tensor x);
Tensor forward_silu(Tensor x);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
void forward(Tensor x, Tensor out, Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {});
void forward(
Tensor x, Tensor out,
Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {},
Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {}, int numTokens = 0
);
std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
Tensor forward_quant(QuantizedActivation qact);
......
......@@ -174,7 +174,7 @@ protected:
}
void debug(std::string name, Tensor tensor) {
if (DebugContext::ctxs.empty()) {
if (DebugContext::ctxs.empty() || !tensor.valid()) {
return;
}
std::string prefix = getFullName();
......
......@@ -69,7 +69,8 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
qact.is_unsigned, qkv_proj.lora_scales, false,
qkv_proj.use_fp4,
*qkv_proj.wtscale.data_ptr<float>(),
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{}
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
{}, {}, {}, 0
);
debug("vk", vk);
......
#include "zgemm.h"
#include "attention.cuh"
#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif
namespace nunchaku::kernels {
void attention_fp16(
Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale
) {
int sizeBatch = q.shape[0];
int numHeads = q.shape[1];
int numTokensQ = q.shape[2];
int headDim = q.shape[3];
int numTokensKV = k.shape[2];
assert(o.ndims() == 3);
assert(o.shape[0] == sizeBatch);
assert(o.shape[1] == numTokensQ);
assert(o.shape[2] == numHeads * headDim);
spdlog::trace("attention_fp16: B={} H={} NQ={} NK={}", sizeBatch, numHeads, numTokensQ, numTokensKV);
spdlog::trace("q at {}", q.data_ptr());
spdlog::trace("k at {}", k.data_ptr());
spdlog::trace("v at {}", v.data_ptr());
spdlog::trace("o at {}", o.data_ptr());
spdlog::trace("scale={}", scale);
dispatchBool(o.scalar_type() == Tensor::BF16, [&]<bool bf16out>() {
#ifndef __INTELLISENSE__
using Attention = typename nunchaku::kernels::Attention<AttentionFP16Config<bf16out>>;
#else
using Attention = typename nunchaku::kernels::Attention<AttentionFP16Config<true>>;
#endif
using GEMM = typename Attention::GEMM;
assert(isTypeMatch<typename Attention::half_t>(q.scalar_type()));
assert(isTypeMatch<typename Attention::half_t>(k.scalar_type()));
assert(isTypeMatch<typename Attention::half_t>(v.scalar_type()));
assert(isTypeMatch<typename Attention::epilogue_half_t>(o.scalar_type()));
int shmem = 0;
// we use exp2 instead of exp in the kernel
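// exp(scale * s) == exp2((scale * log2(e)) * s), so folding log2(e) into the scale is exact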
scale *= M_LOG2E;
assert(numTokensQ % Attention::BLOCK_M == 0);
assert(numTokensKV % Attention::WARP_K == 0);
assert(headDim == Attention::HEAD_DIM);
auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
dim3 grid(numTokensQ / Attention::BLOCK_M, numHeads, sizeBatch);
using packed_q_t = typename Attention::packed_q_t;
using packed_k_t = typename Attention::packed_k_t;
using packed_v_t = typename Attention::packed_v_t;
auto func = invoke_kernel<typename Attention::attention_fp16_kernel<Epilogue>,
const packed_q_t *,
const packed_k_t *,
const packed_v_t *,
float,
int, int,
typename Epilogue::Arguments,
bool>;
shmem = std::max(shmem, Attention::template attention_fp16_kernel<Epilogue>::SHMEM_SIZE);
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
q.data_ptr<packed_q_t>(),
k.data_ptr<packed_k_t>(),
v.data_ptr<packed_v_t>(),
scale,
numTokensQ, numTokensKV,
args,
false
);
checkCUDA(cudaGetLastError());
};
launch.template operator()<typename GEMM::EpilogueDefault>(typename GEMM::EpilogueDefault::Arguments{
.out = o.data_ptr<typename GEMM::half_t>(),
.actualM = sizeBatch * numTokensQ,
.actualN = numHeads * headDim,
});
});
}
}; // namespace nunchaku::kernels
\ No newline at end of file
[The diff for this file is collapsed and not shown.]
......@@ -188,6 +188,13 @@ static void ldmatrix(const void *ptr, uint4 &out) {
);
}
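// movmatrix.sync.aligned.m8n8.trans.b16 transposes an 8x8 tile of 16-bit elements held
// across the warp (one 32-bit register per lane); used by EpiloguePackQKV to transpose V tiles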
template<typename T>
__device__ __forceinline__
static T movmatrix(T x) {
asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(*reinterpret_cast<uint32_t *>(&x)) : "r"(*reinterpret_cast<uint32_t *>(&x)));
return x;
}
// x in low bit, y in high bit
template<int bitwidth, bool use_unsigned>
......@@ -277,6 +284,13 @@ static float cuda_cos(float x) {
return result;
}
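// fast approximate 2^x via the PTX ex2.approx.ftz.f32 instruction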
__device__ __forceinline__
static float cuda_exp2(float x) {
float result;
asm ("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
__forceinline__ __device__
static float cuda_sigmoidf (float a)
......@@ -364,4 +378,12 @@ static float int2float_fast(int val) {
return fval - 12582912.0f;
}
template<typename To, typename From>
__device__ __forceinline__
static To bit_cast(const From &input) {
static_assert(sizeof(To) == sizeof(From));
// not safe but anyway
return *reinterpret_cast<const To *>(&input);
}
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -39,7 +39,11 @@ void gemm_w4a4(
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales
Tensor wcscales,
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
) {
Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
if (!fp4) {
......@@ -53,60 +57,68 @@ void gemm_w4a4(
}
}
invoke_launch(dtype, [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::gemm_w4a4(
act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
wcscales
);
dispatchBool(fp4, [&]<bool USE_FP4>() {
GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
wcscales,
out_q,
out_k,
out_v,
attn_tokens
);
});
});
}
void linearattn_vk_mul_q(Tensor q, Tensor vk) {
invoke_launch(q.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(q, vk);
GEMM_W4A4_Launch<Config, false>::linearattn_vk_mul_q(q, vk);
});
}
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
);
dispatchBool(fp4, [&]<bool USE_FP4>() {
GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
);
});
});
}
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_act(
GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(
input, output, oscales
);
});
}
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(
GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(
input, output, oscales
);
});
......
......@@ -1618,13 +1618,300 @@ public:
}
};
struct EpilogueLiteLA {
struct EpilogueRMSNormRope {
static constexpr int HEAD_DIM = 128;
static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
static constexpr int WARP_N_TILES_PER_HEAD = WARP_N_TILES / NUM_HEADS_PER_WARP;
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2;
using packed_rotemb_t = float4;
static constexpr int WARP_N_ROTEMB_TILES = WARP_N_TILES / NUM_HEADS_PER_WARP * 2;
using rotemb_warp = std::array<packed_rotemb_t, WARP_M_TILES * WARP_N_ROTEMB_TILES>; // 128 regs
struct Arguments {
// **packed** [M, HEAD_DIM] float => [M // 16, HEAD_DIM // 8, WARP_SIZE] of packed_rotemb_t
// aka [M // BLOCK_M, NUM_WARPS, WARP_M_TILES, WARP_N_TILES // NUM_HEADS_PER_WARP * 2, WARP_SIZE]
const packed_rotemb_t *rotary_emb;
const half_t *rmsnorm_weight_q; // [HEAD_DIM]
const half_t *rmsnorm_weight_k; // [HEAD_DIM]
float epsilon;
};
__device__ __forceinline__
static rotemb_warp load_rotemb(const packed_rotemb_t *ptr_rotemb) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
rotemb_warp rotemb;
const packed_rotemb_t *ptrlane = &ptr_rotemb[warpId * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE + laneId];
unrolled_loop<WARP_M_TILES>([&]<int i>() {
unrolled_loop<WARP_N_ROTEMB_TILES>([&]<int j>() {
constexpr int offset = (i * WARP_N_ROTEMB_TILES + j) * WARP_SIZE;
rotemb[i * WARP_N_ROTEMB_TILES + j] = load(&ptrlane[offset]);
});
});
return rotemb;
}
__device__ __forceinline__
static void load_rmsnorm(const half_t *ptr_rmsnorm_weight, half_t *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
static constexpr int PACK_SIZE = HEAD_DIM / WARP_SIZE;
using packed_t = std::array<half_t, PACK_SIZE>;
packed_t pack = load(reinterpret_cast<const packed_t *>(ptr_rmsnorm_weight + laneId * PACK_SIZE));
store<true>(reinterpret_cast<packed_t *>(shmem + laneId * PACK_SIZE), pack);
}
__device__ __forceinline__
static packed_fpsum_t load_rmsnorm_from_shmem(half_t *shmem, int n) {
const int laneId = threadIdx.x % WARP_SIZE;
const int col = n * INSN_N + laneId / 16 * 8; // lane 0-15: n*16+0, lane 16-31: n*16+8
uint4 tmp;
ldmatrix(shmem + col, tmp);
return bit_cast<packed_fpsum_t>(tmp);
}
__device__ __forceinline__
static void apply(fpsum_warp &fpsum, const packed_rotemb_t *ptr_rotemb, const half_t *ptr_rmsnorm_weight, float epsilon) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ half_t shmem_rmsnorm[NUM_WARPS][HEAD_DIM];
load_rmsnorm(ptr_rmsnorm_weight, &shmem_rmsnorm[warpId][0]);
__syncwarp();
rotemb_warp rotemb = load_rotemb(ptr_rotemb);
float rmsnorm_coef[NUM_HEADS_PER_WARP][WARP_M_TILES][2];
auto sqr = [](half2_t val) ALWAYSINLINE {
float2 fval = half22float2(val);
return fval.x * fval.x + fval.y * fval.y;
};
#pragma unroll
for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
float sqrsum[2] = {0.0f, 0.0f};
#pragma unroll
for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[0]);
sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[1]);
sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[2]);
sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[3]);
}
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
sqrsum[0] += __shfl_xor_sync(~0, sqrsum[0], mask);
sqrsum[1] += __shfl_xor_sync(~0, sqrsum[1], mask);
}
rmsnorm_coef[head][m][0] = cuda_frsqrt(sqrsum[0] / HEAD_DIM + epsilon);
rmsnorm_coef[head][m][1] = cuda_frsqrt(sqrsum[1] / HEAD_DIM + epsilon);
}
}
#pragma unroll
for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
packed_f32psum_t rms = packed_fp16_to_fp32(load_rmsnorm_from_shmem(&shmem_rmsnorm[warpId][0], n));
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
packed_f32psum_t pack = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n + n_offset]);
pack.data[0] *= rmsnorm_coef[head][m][0] * rms.data[0];
pack.data[1] *= rmsnorm_coef[head][m][0] * rms.data[1];
pack.data[2] *= rmsnorm_coef[head][m][1] * rms.data[2];
pack.data[3] *= rmsnorm_coef[head][m][1] * rms.data[3];
pack.data[4] *= rmsnorm_coef[head][m][0] * rms.data[4];
pack.data[5] *= rmsnorm_coef[head][m][0] * rms.data[5];
pack.data[6] *= rmsnorm_coef[head][m][1] * rms.data[6];
pack.data[7] *= rmsnorm_coef[head][m][1] * rms.data[7];
auto rope = [](float &x, float &y, float sin, float cos) ALWAYSINLINE {
float ix = x, iy = y;
x = ix * cos - iy * sin;
y = ix * sin + iy * cos;
};
{
packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2];
rope(pack.data[0], pack.data[1], sincos.x, sincos.y);
rope(pack.data[2], pack.data[3], sincos.z, sincos.w);
}
{
packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2 + 1];
rope(pack.data[4], pack.data[5], sincos.x, sincos.y);
rope(pack.data[6], pack.data[7], sincos.z, sincos.w);
}
fpsum[m * WARP_N_TILES + n + n_offset] = packed_fp32_to_fp16(pack);
}
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const bool is_q = bn < binfo.numBlocksN / 3;
const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
if (is_q || is_k) {
apply(
fpsum,
args.rotary_emb + bm * NUM_WARPS * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE,
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon
);
}
}
};
struct EpiloguePackQKV {
using attn_half_t = half;
using attn_half2_t = half2;
using packed_qkv_t = uint4;
static constexpr int HEAD_DIM = 128;
static constexpr int INSN_K_QK = 16;
static constexpr int INSN_K_PV = 16;
struct Arguments {
packed_qkv_t *out_q, *out_k, *out_v;
int actualM;
// !!! stride in number of packed_qkv_t !!!
int strideHead_q;
int strideHead_k;
int strideHead_v;
};
__device__ __forceinline__
static attn_half2_t convert_half2(half2_t input) {
if constexpr (std::is_same_v<half2_t, attn_half2_t>) {
return input;
} else {
float2 fval = half22float2(input);
return float22half2<attn_half2_t>(fval);
}
}
__device__ __forceinline__
static packed_qkv_t pack_q(packed_fpsum_t input) {
packed_qkv_t output;
output.x = bit_cast<int>(convert_half2(input.data[0]));
output.y = bit_cast<int>(convert_half2(input.data[1]));
output.z = bit_cast<int>(convert_half2(input.data[2]));
output.w = bit_cast<int>(convert_half2(input.data[3]));
return output;
}
__device__ __forceinline__
static packed_qkv_t pack_k(packed_fpsum_t input) {
packed_qkv_t output;
output.x = bit_cast<int>(convert_half2(input.data[0]));
output.y = bit_cast<int>(convert_half2(input.data[2]));
output.z = bit_cast<int>(convert_half2(input.data[1]));
output.w = bit_cast<int>(convert_half2(input.data[3]));
return output;
}
__device__ __forceinline__
static half2_t movmatrix(half2_t x) {
asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(*reinterpret_cast<uint32_t *>(&x)) : "r"(*reinterpret_cast<uint32_t *>(&x)));
return x;
static packed_qkv_t pack_v(packed_fpsum_t input) {
packed_qkv_t output;
output.x = bit_cast<int>(convert_half2(movmatrix(input.data[0])));
output.y = bit_cast<int>(convert_half2(movmatrix(input.data[1])));
output.z = bit_cast<int>(convert_half2(movmatrix(input.data[2])));
output.w = bit_cast<int>(convert_half2(movmatrix(input.data[3])));
return output;
}
__device__ __forceinline__
static void mask(packed_qkv_t &pack, uint32_t maskVal, int m, int maxRows) {
const int laneId = threadIdx.x % WARP_SIZE;
if (m * INSN_M + laneId / 4 >= maxRows) {
pack.x = maskVal;
pack.z = maskVal;
}
if (m * INSN_M + laneId / 4 + 8 >= maxRows) {
pack.y = maskVal;
pack.w = maskVal;
}
}
// qkv: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_N_TILES, WARP_SIZE] of packed_qkv_t
template<typename F>
__device__ __forceinline__
static void apply(fpsum_warp &fpsum, packed_qkv_t *ptr_output, int maxRows, F &&funcPack, attn_half2_t maskVal) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
static_assert(HEAD_DIM == WARP_N);
packed_qkv_t *ptrlane = &ptr_output[((warpId * WARP_M_TILES + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];
unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
mask(pack, bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack);
});
});
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const int numBlocksQ = binfo.numBlocksN / 3;
const bool is_q = bn < numBlocksQ;
const bool is_k = !is_q && bn < numBlocksQ * 2;
// bn is head_id (assume HEAD_DIM == WARP_N)
int head_id, strideHead;
if (is_q) {
head_id = bn;
strideHead = args.strideHead_q;
} else if (is_k) {
head_id = bn - numBlocksQ;
strideHead = args.strideHead_k;
} else {
head_id = bn - numBlocksQ * 2;
strideHead = args.strideHead_v;
}
int block_offset = head_id * strideHead + bm * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE;
int maxRows = args.actualM - bm * BLOCK_M;
// static constexpr float neginf = -std::numeric_limits<float>::infinity();
if (is_q) {
apply(fpsum, args.out_q + block_offset, maxRows, pack_q, attn_half2_t(0.0f, 0.0f));
} else if (is_k) {
apply(fpsum, args.out_k + block_offset, maxRows, pack_k, attn_half2_t(NAN, NAN));
} else {
apply(fpsum, args.out_v + block_offset, maxRows, pack_v, attn_half2_t(0.0f, 0.0f));
}
}
};
struct EpilogueLiteLA {
__device__ __forceinline__
static packed_f32psum_t mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) {
......@@ -1874,6 +2161,62 @@ public:
);
}
};
template<typename Epilogue>
struct test_epilogue_kernel {
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
struct Arguments {
const half_t *input;
half_t *output;
// aligned to BLOCK_M and BLOCK_N
int M, N;
int actualM, actualN;
typename Epilogue::Arguments argsEpilogue;
};
__device__ __forceinline__
void operator()(Arguments args)
{
const BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
const int bm = binfo.bm;
const int bn = binfo.bn;
const int warpId = threadIdx.x / WARP_SIZE;
const int m_offset = bm * BLOCK_M + warpId * WARP_M;
const int n_offset = bn * BLOCK_N;
extern __shared__ uint8_t shmem[];
fpsum_warp fpsum;
load_act_to_fpsum<false>()(
args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
fpsum,
shmem + warpId * SHMEM_PER_WARP
);
Epilogue()(binfo, fpsum, args.M, args.N, 0, args.argsEpilogue);
EpilogueDefault()(binfo, fpsum, args.M, args.N, 0, typename EpilogueDefault::Arguments{
.out = args.output,
.actualM = args.actualM,
.actualN = args.actualN,
});
}
};
};
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -2,7 +2,7 @@
namespace nunchaku::kernels {
template<typename Config>
template<typename Config, bool USE_FP4>
class GEMM_W4A4_Launch {
using GEMM = GEMM_W4A4<Config>;
// using LoraRanks = std::integer_sequence<int, 0, 32>;
......@@ -48,7 +48,11 @@ public:
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales // packed ws [N]
Tensor wcscales, // packed ws [N]
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
);
static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
......
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16>;
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, true>;
};
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, false>;
};
\ No newline at end of file
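For reference, a hedged sketch of what one of the new per-precision translation units (e.g. gemm_w4a4_launch_fp16_int4.cu) presumably contains, mirroring the BF16 instantiations shown above; GEMMConfig_W4A4_FP16 is an assumed config name:

#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
// FP16 activations, INT4 quantization (USE_FP4 = false)
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;
};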