"vscode:/vscode.git/clone" did not exist on "d791ef440b45f57b4af7e4bb23dc539fe3e8ebdc"
Commit 92ac7b40 authored by sxtyzhangzk, committed by Zhekai Zhang

Add our own FP16 Attention implementation

parent 182c323c
#include "gemm_w4a4_launch_impl.cuh" #include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels { namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>; template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>;
}; };
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;
};
\ No newline at end of file
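
The FP4 toggle is now a compile-time parameter of GEMM_W4A4_Launch, with one explicit instantiation per value in the two translation units above, instead of a runtime dispatchBool(fp4, ...) inside each launcher. A minimal self-contained sketch of that runtime-bool-to-template-parameter pattern follows; LaunchSketch and dispatchBoolSketch are illustrative stand-ins for GEMM_W4A4_Launch and the repo's dispatchBool helper, not the real API.

```cpp
#include <cassert>
#include <utility>

// Stand-in for GEMM_W4A4_Launch<Config, USE_FP4>: the FP4 choice is baked in
// at compile time, and the runtime flag is only cross-checked with an assert,
// mirroring the assert(fp4 == USE_FP4) added in this commit.
template<bool USE_FP4>
struct LaunchSketch {
    static void gemm_w4a4(bool fp4) {
        assert(fp4 == USE_FP4);
        // ... kernel launch would go here ...
    }
};

// Stand-in for the dispatchBool helper: maps a runtime bool onto a
// compile-time template parameter exactly once, at the outermost call site.
template<typename F>
void dispatchBoolSketch(bool value, F&& f) {
    if (value) {
        std::forward<F>(f).template operator()<true>();
    } else {
        std::forward<F>(f).template operator()<false>();
    }
}

void call_site(bool fp4) {
    dispatchBoolSketch(fp4, [&]<bool USE_FP4>() {
        LaunchSketch<USE_FP4>::gemm_w4a4(fp4);
    });
}
```

Both explicit instantiations exist so that whichever value a call site picks has already been compiled.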
@@ -3,11 +3,11 @@
namespace nunchaku::kernels {
#ifndef __INTELLISENSE__
-template<typename Config>
-void GEMM_W4A4_Launch<Config>::gemm_w4a4(
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
#else
template<>
-void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
+void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
#endif
    Tensor act, // packed act [M, K / 2]
    Tensor wgt, // packed act [N, K / 2]
@@ -33,8 +33,17 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
    bool fuse_silu,
    bool fp4,
    float alpha,
-   Tensor wcscales // packed ws [N]
+   Tensor wcscales, // packed ws [N]
+   Tensor out_q,   // packed attention [B, H, M, D]
+   Tensor out_k,   // packed attention [B, H, M, D]
+   Tensor out_v,   // packed attention [B, H, M, D]
+   int attn_tokens
) {
+#ifdef __INTELLISENSE__
+   static constexpr bool USE_FP4 = false;
+#endif
+   assert(fp4 == USE_FP4);
    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1] * 2;
@@ -71,90 +80,88 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
            std::swap(grid.x, grid.y);
        }

-       dispatchBool(fp4, [&]<bool USE_FP4>() {
        // test_sizeof<typename Epilogue::Arguments>();
        // std::apply([](auto ...args) {
        //     (test_sizeof<decltype(args)>(), ...);
        // }, args);

        // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;

        if constexpr (!USE_FP4) {
            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
                    const packed_act_t *,
                    const packed_wgt_t *,
                    const packed_ascale_t *,
                    const packed_wscale_t *,
                    int, int, int,
                    typename Epilogue::Arguments,
                    bool,
                    bool>;

                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }

+               assert(alpha == 1.0f);

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_ascale_t>(),
                    wscales.data_ptr<packed_wscale_t>(),
                    M, N, K,
                    args,
                    swapBlockMN,
                    false
                );
                checkCUDA(cudaGetLastError());
            });
            return;
        }

        if constexpr (USE_FP4) {
            dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
                assert(!act_unsigned);

                auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
                    const packed_act_t *,
                    const packed_wgt_t *,
                    const packed_amscale_t *,
                    const packed_wmscale_t *,
                    float,
                    int, int, int,
                    typename Epilogue::Arguments,
                    bool,
                    bool>;

                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }

                assert(ascales.dtype() == Tensor::FP8_E4M3);
                assert(wscales.dtype() == Tensor::FP8_E4M3);

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_amscale_t>(),
                    wscales.data_ptr<packed_wmscale_t>(),
                    alpha,
                    M, N, K,
                    args,
                    swapBlockMN,
                    false
                );
                checkCUDA(cudaGetLastError());
            });
            return;
        }

        // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
        //     throw std::runtime_error("FP4 kernel is not available");
        // }
-       });
    };
    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
@@ -262,30 +269,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
            static constexpr float SHIFT_GELU = 0.171875f;

-           dispatchBool(fp4, [&]<bool USE_FP4>() {
            constexpr bool USE_UNSIGNED = !USE_FP4;
            using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
            auto argsQuantize = typename EpilogueQuantize::Arguments{
                .qout = qout.data_ptr<packed_act_t>(),
                .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
                .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
                .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
            };

            // TODO: check if gelu is needed
            if (out.valid()) {
                launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
                    typename GEMM::EpilogueDefault::Arguments{
                        .out = out.data_ptr<half_t>(),
                        .actualM = actualM,
                        .actualN = actualN,
                    },
                    argsQuantize
                }, {});
            } else {
                launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
            }
-           });
        } else if (out_linearattn.valid()) {
@@ -327,17 +332,54 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
            assert(norm_k.valid());
            // assert(isTypeMatch<half_t>(rotary_emb.scalar_type()));
            assert(rotary_emb.scalar_type() == Tensor::FP32);
-           assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
-           launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
-               .out = out.data_ptr<half_t>(),
-               .actualM = actualM,
-               .actualN = actualN,
-               .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
-               .rotary_emb = rotary_emb.data_ptr<float>(),
+           assert(rotary_emb.ndims() == 3);
+           assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
+           assert(rotary_emb.shape[2] == GEMM::EpilogueRMSNormRope::HEAD_DIM);
+
+           // assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
+           // launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
+           //     .out = out.data_ptr<half_t>(),
+           //     .actualM = actualM,
+           //     .actualN = actualN,
+           //     .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
+           //     .rotary_emb = rotary_emb.data_ptr<float>(),
+           //     .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
+           //     .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
+           //     .epsilon = 1e-6,
+           // }, {});
+
+           using EpilogueRope = typename GEMM::EpilogueRMSNormRope;
+           auto argsRope = typename GEMM::EpilogueRMSNormRope::Arguments{
+               .rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
                .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
                .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
                .epsilon = 1e-6,
-           }, {});
+           };
+
+           if (out_q.valid()) {
+               launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
+                   argsRope,
+                   typename GEMM::EpiloguePackQKV::Arguments{
+                       .out_q = out_q.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                       .out_k = out_k.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                       .out_v = out_v.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                       .actualM = attn_tokens,
+                       .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+                       .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+                       .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+                   }
+               }, {});
+           } else {
+               launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>, typename GEMM::EpilogueNop>({
+                   argsRope,
+                   typename GEMM::EpilogueDefault::Arguments{
+                       .out = out.data_ptr<half_t>(),
+                       .actualM = actualM,
+                       .actualN = actualN,
+                   }
+               }, {});
+           }
        } else if (out.valid()) {
            using Epilogue = typename GEMM::EpilogueDefault;
@@ -357,8 +399,8 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
    }
}

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    using Epilogue = typename GEMM::EpilogueLiteLA;

    int batch_size = vk.shape[0];
@@ -384,8 +426,8 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    checkCUDA(cudaGetLastError());
}

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];
@@ -418,38 +460,41 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
    dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
        dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
-           dispatchBool(fp4, [&]<bool USE_FP4>() {
            using Lora = typename GEMM::Lora<RANK>;
            using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;

            auto func = invoke_kernel<kernel, typename kernel::Arguments>;

            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

            // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));

            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
                typename kernel::Arguments{
                    .input = input.data_ptr<half_t>(),
                    .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                    .output = output.data_ptr<packed_act_t>(),
                    .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
                    .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                    .lora_act = lora_act_out.data_ptr<float>(),
                    .M = M,
                    .N = N,
                    .actualM = actualM,
                    .actualN = actualN,
                }
            );
            checkCUDA(cudaGetLastError());
-           });
        });
    });
}

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
+   if constexpr (USE_FP4) {
+       assert(false);  // not implemented
+       return;
+   }
    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];
@@ -471,8 +516,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Te
    checkCUDA(cudaGetLastError());
}

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
+   if constexpr (USE_FP4) {
+       assert(false);
+       return;
+   }
    int N = input.numel() / input.shape[-1];
    int K = input.shape[-1];
...
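
The strideHead_q/k/v fields in the EpiloguePackQKV arguments above convert the per-head stride of a [B, H, M, D] output tensor from elements into units of packed_qkv_t: stride in elements, times bytes per element, divided by the packed type's size. A small worked example under assumed sizes; the FP16 element size, the 16-byte packed_qkv_t, and the M/D values are illustrative only.

```cpp
#include <cstdio>

int main() {
    // Illustrative sizes only; the real ones come from the tensor layout and
    // from sizeof(GEMM::EpiloguePackQKV::packed_qkv_t).
    const long long M = 4096;           // tokens per head
    const long long D = 128;            // head dimension
    const long long scalar_size = 2;    // bytes per FP16 element
    const long long packed_bytes = 16;  // hypothetical sizeof(packed_qkv_t)

    // For a contiguous [B, H, M, D] tensor, stride(1) is the number of
    // elements between consecutive heads.
    const long long stride1_elems = M * D;

    // Same conversion as .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(packed_qkv_t)).
    const long long strideHead = stride1_elems * scalar_size / packed_bytes;

    std::printf("strideHead = %lld packed_qkv_t per head\n", strideHead);  // 65536
    return 0;
}
```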
#include "zgemm.h"
#include "gemm_w4a4.cuh"
namespace nunchaku::kernels {
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}
}
);
checkCUDA(cudaGetLastError());
}
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.scalar_type() == Tensor::FP16);
Tensor output = Tensor::empty_like(input);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
.actualM = numTokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
}
}
);
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
\ No newline at end of file
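
test_rmsnorm_rope runs the EpilogueRMSNormRope epilogue in isolation, which makes it natural to compare its output against a host reference. Below is a minimal sketch of the RMSNorm half of such a reference, using the same epsilon = 1e-6 passed above and the usual formulation y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i]; the RoPE half and the exact per-head data layout depend on HEAD_DIM and packed_rotemb_t and are left out, and nothing in this sketch is taken from the commit itself.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Host reference for RMSNorm over one head of length head_dim, assuming the
// standard formulation; eps matches the .epsilon = 1e-6 used by the epilogue.
std::vector<float> rmsnorm_ref(const std::vector<float>& x,
                               const std::vector<float>& weight,
                               float eps = 1e-6f) {
    double sumsq = 0.0;
    for (float v : x) {
        sumsq += double(v) * double(v);
    }
    const float inv_rms = 1.0f / std::sqrt(float(sumsq / x.size()) + eps);

    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] * inv_rms * weight[i];  // normalize, then scale by the learned weight
    }
    return y;
}
```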
@@ -30,7 +30,11 @@ void gemm_w4a4(
    bool fuse_silu,
    bool fp4,
    float alpha,
-   Tensor wcscales
+   Tensor wcscales,
+   Tensor out_q,   // packed attention [B, H, M, D]
+   Tensor out_k,   // packed attention [B, H, M, D]
+   Tensor out_v,   // packed attention [B, H, M, D]
+   int attn_tokens
);

void linearattn_vk_mul_q(Tensor q, Tensor vk);
@@ -57,4 +61,17 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N]
// );

+void attention_fp16(
+   Tensor q,   // packed [Batch, Head, TokensQ, HEAD_DIM]
+   Tensor k,   // packed [Batch, Head, TokensKV, HEAD_DIM]
+   Tensor v,   // packed [Batch, Head, TokensKV, HEAD_DIM]
+   Tensor o,   // linear [Batch, TokensQ, Head * HEAD_DIM]
+   float scale
+);
+
+// FOR TEST ONLY
+void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
+void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);
+
}; // namespace nunchaku::kernels
\ No newline at end of file
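
attention_fp16 takes the softmax scale as an explicit argument rather than deriving it internally. A hedged sketch of a call site follows; it assumes q/k/v/o are already allocated and packed as the comments above describe, and that the conventional 1/sqrt(HEAD_DIM) scaling is wanted; neither assumption comes from the commit itself.

```cpp
#include <cmath>

// Hypothetical call site; q, k, v are packed [Batch, Head, Tokens, HEAD_DIM]
// and o is linear [Batch, TokensQ, Head * HEAD_DIM], per the comments above.
void run_attention_fp16(Tensor q, Tensor k, Tensor v, Tensor o) {
    const int head_dim = q.shape[3];                        // HEAD_DIM of the packed layout
    const float scale = 1.0f / std::sqrt(float(head_dim));  // conventional softmax scaling (assumption)
    nunchaku::kernels::attention_fp16(q, k, v, o, scale);
}
```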