Commit 92ac7b40 authored by sxtyzhangzk, committed by Zhekai Zhang

Add our own FP16 Attention implementation

parent 182c323c
#include "gemm_w4a4_launch_impl.cuh" #include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels { namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>; template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>;
}; };
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;
};
\ No newline at end of file
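The two translation units above pin the FP4 flag at compile time: the launcher template now carries a second bool parameter, and each .cu file explicitly instantiates exactly one variant. A minimal, self-contained sketch of that explicit-instantiation pattern (Launcher and ConfigFP16 are illustrative stand-ins, not the project's types):

// Sketch only: the real launcher declares its members in a header and defines
// them in gemm_w4a4_launch_impl.cuh; here everything is inline for brevity.
template<typename Config, bool USE_FP4>
struct Launcher {
    void run() const {
        // A real implementation would select the FP4 or non-FP4 kernels here,
        // guarded by "if constexpr (USE_FP4)".
    }
};

struct ConfigFP16 {};

// One translation unit instantiates the FP4 variant...
template struct Launcher<ConfigFP16, true>;
// ...and a second one instantiates the non-FP4 (INT4) variant.
template struct Launcher<ConfigFP16, false>;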
@@ -3,11 +3,11 @@
 namespace nunchaku::kernels {

 #ifndef __INTELLISENSE__
-template<typename Config>
-void GEMM_W4A4_Launch<Config>::gemm_w4a4(
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
 #else
 template<>
-void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
+void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
 #endif
     Tensor act, // packed act [M, K / 2]
     Tensor wgt, // packed act [N, K / 2]
@@ -33,8 +33,17 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
     bool fuse_silu,
     bool fp4,
     float alpha,
-    Tensor wcscales // packed ws [N]
+    Tensor wcscales, // packed ws [N]
+    Tensor out_q,    // packed attention [B, H, M, D]
+    Tensor out_k,    // packed attention [B, H, M, D]
+    Tensor out_v,    // packed attention [B, H, M, D]
+    int attn_tokens
 ) {
+#ifdef __INTELLISENSE__
+    static constexpr bool USE_FP4 = false;
+#endif
+    assert(fp4 == USE_FP4);
+
     int M = act.numel() / act.shape[-1];
     int N = wgt.shape[0];
     int K = act.shape[-1] * 2;
@@ -71,7 +80,6 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         std::swap(grid.x, grid.y);
     }

-    dispatchBool(fp4, [&]<bool USE_FP4>() {
     // test_sizeof<typename Epilogue::Arguments>();
     // std::apply([](auto ...args) {
     //     (test_sizeof<decltype(args)>(), ...);
@@ -154,7 +162,6 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
         //     throw std::runtime_error("FP4 kernel is not available");
         // }
-        });
     };

     auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
@@ -262,7 +269,6 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         static constexpr float SHIFT_GELU = 0.171875f;

-        dispatchBool(fp4, [&]<bool USE_FP4>() {
         constexpr bool USE_UNSIGNED = !USE_FP4;
         using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
         auto argsQuantize = typename EpilogueQuantize::Arguments{
@@ -285,7 +291,6 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         } else {
             launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
         }
-        });

     } else if (out_linearattn.valid()) {
@@ -327,17 +332,54 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         assert(norm_k.valid());
         // assert(isTypeMatch<half_t>(rotary_emb.scalar_type()));
         assert(rotary_emb.scalar_type() == Tensor::FP32);
-        assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
-        launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
-            .out = out.data_ptr<half_t>(),
-            .actualM = actualM,
-            .actualN = actualN,
-            .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
-            .rotary_emb = rotary_emb.data_ptr<float>(),
+        assert(rotary_emb.ndims() == 3);
+        assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
+        assert(rotary_emb.shape[2] == GEMM::EpilogueRMSNormRope::HEAD_DIM);
+
+        // assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
+        // launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
+        //     .out = out.data_ptr<half_t>(),
+        //     .actualM = actualM,
+        //     .actualN = actualN,
+        //     .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
+        //     .rotary_emb = rotary_emb.data_ptr<float>(),
+        //     .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
+        //     .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
+        //     .epsilon = 1e-6,
+        // }, {});
+
+        using EpilogueRope = typename GEMM::EpilogueRMSNormRope;
+        auto argsRope = typename GEMM::EpilogueRMSNormRope::Arguments{
+            .rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
             .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
             .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
             .epsilon = 1e-6,
+        };
+
+        if (out_q.valid()) {
+            launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
+                argsRope,
+                typename GEMM::EpiloguePackQKV::Arguments{
+                    .out_q = out_q.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                    .out_k = out_k.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                    .out_v = out_v.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                    .actualM = attn_tokens,
+                    .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+                    .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+                    .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+                }
             }, {});
+        } else {
+            launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>, typename GEMM::EpilogueNop>({
+                argsRope,
+                typename GEMM::EpilogueDefault::Arguments{
+                    .out = out.data_ptr<half_t>(),
+                    .actualM = actualM,
+                    .actualN = actualN,
+                }
+            }, {});
+        }

     } else if (out.valid()) {
         using Epilogue = typename GEMM::EpilogueDefault;
@@ -357,8 +399,8 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
     }
 }

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
     using Epilogue = typename GEMM::EpilogueLiteLA;

     int batch_size = vk.shape[0];
@@ -384,8 +426,8 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
     checkCUDA(cudaGetLastError());
 }

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
     const int actualM = input.numel() / input.shape[-1];
     const int actualN = input.shape[-1];
@@ -418,7 +460,6 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
     dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
         dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
-            dispatchBool(fp4, [&]<bool USE_FP4>() {
             using Lora = typename GEMM::Lora<RANK>;
             using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
@@ -445,11 +486,15 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
             checkCUDA(cudaGetLastError());
         });
     });
-    });
 }

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
+    if constexpr (USE_FP4) {
+        assert(false); // not implemented
+        return;
+    }
+
     int M = input.numel() / input.shape[-1];
     int K = input.shape[-1];
@@ -471,8 +516,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
     checkCUDA(cudaGetLastError());
 }

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
+    if constexpr (USE_FP4) {
+        assert(false);
+        return;
+    }
+
     int N = input.numel() / input.shape[-1];
     int K = input.shape[-1];
...
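In the launcher body, the removed dispatchBool(fp4, ...) wrappers are what previously turned the runtime fp4 flag into a compile-time constant; with USE_FP4 now a class template parameter (checked once by assert(fp4 == USE_FP4)), that per-call-site dispatch is no longer needed. A hedged sketch of how such a helper typically works (an assumption about its behavior, not the project's actual definition):

#include <utility>

// dispatchBoolSketch: branch once on the runtime value and invoke the generic
// lambda with the matching compile-time constant, so the lambda body can use
// it in "if constexpr" and as a template argument.
template<typename F>
void dispatchBoolSketch(bool value, F &&f) {
    if (value) {
        std::forward<F>(f).template operator()<true>();
    } else {
        std::forward<F>(f).template operator()<false>();
    }
}

// Usage (C++20 templated lambda, as in the surrounding code):
//   dispatchBoolSketch(fp4, [&]<bool USE_FP4>() { /* pick kernel variant */ });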
#include "zgemm.h"
#include "gemm_w4a4.cuh"
namespace nunchaku::kernels {
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}
}
);
checkCUDA(cudaGetLastError());
}
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.scalar_type() == Tensor::FP16);
Tensor output = Tensor::empty_like(input);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
.actualM = numTokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
}
}
);
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
\ No newline at end of file
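Both test entry points above follow the same launch recipe: raise the kernel's dynamic shared-memory limit with cudaFuncSetAttribute, build a grid of (M / BLOCK_M) x (N / BLOCK_N) tiles, and launch WARP_SIZE * NUM_WARPS threads per block with SHMEM_SIZE bytes of dynamic shared memory on the current stream. A standalone CUDA sketch of that recipe with illustrative constants (the real values come from the GEMM configuration, not from here):

#include <cuda_runtime.h>

// Illustrative tile/launch constants; the project's GEMM config defines its own.
constexpr int BLOCK_M = 128;
constexpr int BLOCK_N = 128;
constexpr int WARP_SIZE = 32;
constexpr int NUM_WARPS = 8;
constexpr int SHMEM_SIZE = 64 * 1024;

__global__ void tile_kernel(int M, int N) {
    extern __shared__ char smem[]; // dynamic shared-memory carve-out
    (void)smem;
    // blockIdx.x selects the M tile, blockIdx.y the N tile, as in the tests.
}

void launch_tiles(int M, int N, cudaStream_t stream) {
    // Opt in to more than the default dynamic shared memory per block.
    cudaFuncSetAttribute(tile_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SIZE);
    dim3 grid(M / BLOCK_M, N / BLOCK_N);
    tile_kernel<<<grid, WARP_SIZE * NUM_WARPS, SHMEM_SIZE, stream>>>(M, N);
}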
@@ -30,7 +30,11 @@ void gemm_w4a4(
     bool fuse_silu,
     bool fp4,
     float alpha,
-    Tensor wcscales
+    Tensor wcscales,
+    Tensor out_q, // packed attention [B, H, M, D]
+    Tensor out_k, // packed attention [B, H, M, D]
+    Tensor out_v, // packed attention [B, H, M, D]
+    int attn_tokens
 );

 void linearattn_vk_mul_q(Tensor q, Tensor vk);
@@ -57,4 +61,17 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);
     // Tensor wscales // [1, N]
 // );

+void attention_fp16(
+    Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
+    Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
+    Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
+    Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
+    float scale
+);
+
+// FOR TEST ONLY
+void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
+void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);
+
 }; // namespace nunchaku::kernels
\ No newline at end of file
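For reference, the new attention_fp16 declaration above describes standard scaled dot-product attention over packed [Batch, Head, Tokens, HEAD_DIM] q/k/v with a row-major [Batch, TokensQ, Head * HEAD_DIM] output. A naive CPU sketch of that contract, useful only as a numerical reference against the FP16 kernel (this illustrates the declared interface, not the shipped implementation):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// o[b][tq][h * D + d] = sum_tk softmax(scale * q.k^T)[tq][tk] * v[b][h][tk][d]
void attention_reference(const std::vector<float> &q,  // [B, H, Tq, D]
                         const std::vector<float> &k,  // [B, H, Tk, D]
                         const std::vector<float> &v,  // [B, H, Tk, D]
                         std::vector<float> &o,        // [B, Tq, H * D]
                         int B, int H, int Tq, int Tk, int D, float scale) {
    auto idx = [&](int b, int h, int t, int d, int T) {
        return ((std::size_t(b) * H + h) * T + t) * D + d;
    };
    for (int b = 0; b < B; b++)
    for (int h = 0; h < H; h++)
    for (int tq = 0; tq < Tq; tq++) {
        std::vector<float> w(Tk);
        float m = -INFINITY;
        for (int tk = 0; tk < Tk; tk++) {          // logits = scale * <q, k>
            float dot = 0.f;
            for (int d = 0; d < D; d++)
                dot += q[idx(b, h, tq, d, Tq)] * k[idx(b, h, tk, d, Tk)];
            w[tk] = dot * scale;
            m = std::max(m, w[tk]);
        }
        float denom = 0.f;                         // numerically stable softmax
        for (int tk = 0; tk < Tk; tk++) {
            w[tk] = std::exp(w[tk] - m);
            denom += w[tk];
        }
        for (int d = 0; d < D; d++) {              // weighted sum of v rows
            float acc = 0.f;
            for (int tk = 0; tk < Tk; tk++)
                acc += w[tk] * v[idx(b, h, tk, d, Tk)];
            o[(std::size_t(b) * Tq + tq) * H * D + std::size_t(h) * D + d] = acc / denom;
        }
    }
}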