Commit dd5d4bb3 authored by zhanghj2
Browse files

区分dim576和512

parent c3cf875a
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "../../helpers.h" #include "../../helpers.h"
namespace sm90::fwd { namespace sm90::fwd {
#define CUDART_L2E_F 1.442695041F
using namespace cute; using namespace cute;
...@@ -54,8 +55,11 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn ...@@ -54,8 +55,11 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
flash::lds_direct_copy<false, true, true>(gQ, sQ, 0, params.stride_q_h_q, params.h_q - bidh * kBlockM); flash::lds_direct_copy<false, true, true>(gQ, sQ, 0, params.stride_q_h_q, params.h_q - bidh * kBlockM);
flash::lds_direct_copy<false, true, true>(gQ, sQ, 1, params.stride_q_h_q, params.h_q - bidh * kBlockM); flash::lds_direct_copy<false, true, true>(gQ, sQ, 1, params.stride_q_h_q, params.h_q - bidh * kBlockM);
flash::lds_direct_copy<false, true, true>(gQ, sQ, 2, params.stride_q_h_q, params.h_q - bidh * kBlockM); flash::lds_direct_copy<false, true, true>(gQ, sQ, 2, params.stride_q_h_q, params.h_q - bidh * kBlockM);
flash::lds_direct_copy<false, true, true>(gQ, sQ, 3, params.stride_q_h_q, params.h_q - bidh * kBlockM); flash::lds_direct_copy<false, true, true>(gQ, sQ, 3, params.stride_q_h_q, params.h_q - bidh * kBlockM);
flash::lds_direct_copy<false, false, true>(gQ, sQ, 4, params.stride_q_h_q, params.h_q - bidh * kBlockM); if constexpr (D_QK == 576)
{
flash::lds_direct_copy<false, false, true>(gQ, sQ, 4, params.stride_q_h_q, params.h_q - bidh * kBlockM);
}
auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma); auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
...@@ -64,29 +68,56 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn ...@@ -64,29 +68,56 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); if constexpr (D_QK == 576)
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); {
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9)); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9));
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13)); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13));
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 16), tSrQ_copy_view(_, _, 16)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 17), tSrQ_copy_view(_, _, 17)); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 16), tSrQ_copy_view(_, _, 16));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 17), tSrQ_copy_view(_, _, 17));
}
else
{
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11));
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15));
}
__syncthreads(); __syncthreads();
...@@ -383,7 +414,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn ...@@ -383,7 +414,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse; float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse;
if (params.attn_sink != nullptr) { if (params.attn_sink != nullptr) {
float rAttn_sink = __ldg((float*)params.attn_sink + start_head_idx + lane_idx % 16); float rAttn_sink = __ldg((float*)params.attn_sink + bidh * kBlockM + lane_idx % 16);
if (flash::is_positive_infinity(rAttn_sink)) if (flash::is_positive_infinity(rAttn_sink))
{ {
for (int i = 0; i < size(acc_o); i++) for (int i = 0; i < size(acc_o); i++)
...@@ -455,7 +486,7 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para ...@@ -455,7 +486,7 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_ASSERT(params.h_q % B_H == 0); KU_ASSERT(params.h_q % B_H == 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>>; auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>>;
constexpr size_t smem_size = 65536; constexpr size_t smem_size = 65536;
dim3 grid((params.s_q, params.h_q/B_H), 1); dim3 grid(params.s_q, params.h_q/B_H, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params); kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH(); KU_CHECK_KERNEL_LAUNCH();
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment