Commit 92ac7b40 authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

Add our own FP16 Attention implementation

parent 182c323c
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>;
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>;
};
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;
};
\ No newline at end of file
......@@ -3,11 +3,11 @@
namespace nunchaku::kernels {
#ifndef __INTELLISENSE__
template<typename Config>
void GEMM_W4A4_Launch<Config>::gemm_w4a4(
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
#else
template<>
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
#endif
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
......@@ -33,8 +33,17 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales // packed ws [N]
Tensor wcscales, // packed ws [N]
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
) {
#ifdef __INTELLISENSE__
static constexpr bool USE_FP4 = false;
#endif
assert(fp4 == USE_FP4);
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1] * 2;
......@@ -71,90 +80,88 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
std::swap(grid.x, grid.y);
}
dispatchBool(fp4, [&]<bool USE_FP4>() {
// test_sizeof<typename Epilogue::Arguments>();
// std::apply([](auto ...args) {
// (test_sizeof<decltype(args)>(), ...);
// }, args);
// constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
if constexpr (!USE_FP4) {
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
const packed_act_t *,
const packed_wgt_t *,
const packed_ascale_t *,
const packed_wscale_t *,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
assert(alpha == 1.0f);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
wscales.data_ptr<packed_wscale_t>(),
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
return;
}
// test_sizeof<typename Epilogue::Arguments>();
// std::apply([](auto ...args) {
// (test_sizeof<decltype(args)>(), ...);
// }, args);
// constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
if constexpr (!USE_FP4) {
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
const packed_act_t *,
const packed_wgt_t *,
const packed_ascale_t *,
const packed_wscale_t *,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
if constexpr (USE_FP4) {
dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
assert(!act_unsigned);
auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
const packed_act_t *,
const packed_wgt_t *,
const packed_amscale_t *,
const packed_wmscale_t *,
float,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
assert(ascales.dtype() == Tensor::FP8_E4M3);
assert(wscales.dtype() == Tensor::FP8_E4M3);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_amscale_t>(),
wscales.data_ptr<packed_wmscale_t>(),
alpha,
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
assert(alpha == 1.0f);
return;
}
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
wscales.data_ptr<packed_wscale_t>(),
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
return;
}
// if constexpr (USE_FP4 && !FP4_AVAILABLE) {
// throw std::runtime_error("FP4 kernel is not available");
// }
});
if constexpr (USE_FP4) {
dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
assert(!act_unsigned);
auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
const packed_act_t *,
const packed_wgt_t *,
const packed_amscale_t *,
const packed_wmscale_t *,
float,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
assert(ascales.dtype() == Tensor::FP8_E4M3);
assert(wscales.dtype() == Tensor::FP8_E4M3);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_amscale_t>(),
wscales.data_ptr<packed_wmscale_t>(),
alpha,
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
return;
}
// if constexpr (USE_FP4 && !FP4_AVAILABLE) {
// throw std::runtime_error("FP4 kernel is not available");
// }
};
auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
......@@ -262,30 +269,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
static constexpr float SHIFT_GELU = 0.171875f;
dispatchBool(fp4, [&]<bool USE_FP4>() {
constexpr bool USE_UNSIGNED = !USE_FP4;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
.shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
};
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
},
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
}
});
constexpr bool USE_UNSIGNED = !USE_FP4;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
.shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
};
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
},
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
}
} else if (out_linearattn.valid()) {
......@@ -327,17 +332,54 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
assert(norm_k.valid());
// assert(isTypeMatch<half_t>(rotary_emb.scalar_type()));
assert(rotary_emb.scalar_type() == Tensor::FP32);
assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
.pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
.rotary_emb = rotary_emb.data_ptr<float>(),
assert(rotary_emb.ndims() == 3);
assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
assert(rotary_emb.shape[2] == GEMM::EpilogueRMSNormRope::HEAD_DIM);
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
// launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
// .out = out.data_ptr<half_t>(),
// .actualM = actualM,
// .actualN = actualN,
// .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
// .rotary_emb = rotary_emb.data_ptr<float>(),
// .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
// .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
// .epsilon = 1e-6,
// }, {});
using EpilogueRope = typename GEMM::EpilogueRMSNormRope;
auto argsRope = typename GEMM::EpilogueRMSNormRope::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
.epsilon = 1e-6,
}, {});
};
if (out_q.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
argsRope,
typename GEMM::EpiloguePackQKV::Arguments{
.out_q = out_q.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
.actualM = attn_tokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
}
}, {});
} else {
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>, typename GEMM::EpilogueNop>({
argsRope,
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
}
}, {});
}
} else if (out.valid()) {
using Epilogue = typename GEMM::EpilogueDefault;
......@@ -357,8 +399,8 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
}
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
using Epilogue = typename GEMM::EpilogueLiteLA;
int batch_size = vk.shape[0];
......@@ -384,8 +426,8 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
checkCUDA(cudaGetLastError());
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
const int actualM = input.numel() / input.shape[-1];
const int actualN = input.shape[-1];
......@@ -418,38 +460,41 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
dispatchBool(fp4, [&]<bool USE_FP4>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
}
);
checkCUDA(cudaGetLastError());
});
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
}
);
checkCUDA(cudaGetLastError());
});
});
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
if constexpr (USE_FP4) {
assert(false); // not implemented
return;
}
int M = input.numel() / input.shape[-1];
int K = input.shape[-1];
......@@ -471,8 +516,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Te
checkCUDA(cudaGetLastError());
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
if constexpr (USE_FP4) {
assert(false);
return;
}
int N = input.numel() / input.shape[-1];
int K = input.shape[-1];
......
#include "zgemm.h"
#include "gemm_w4a4.cuh"
namespace nunchaku::kernels {
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}
}
);
checkCUDA(cudaGetLastError());
}
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.scalar_type() == Tensor::FP16);
Tensor output = Tensor::empty_like(input);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
.actualM = numTokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
}
}
);
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -30,7 +30,11 @@ void gemm_w4a4(
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales
Tensor wcscales,
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
);
void linearattn_vk_mul_q(Tensor q, Tensor vk);
......@@ -57,4 +61,17 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N]
// );
void attention_fp16(
Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale
);
// FOR TEST ONLY
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);
}; // namespace nunchaku::kernels
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment