Commit 2033d805 authored by zhanghj2's avatar zhanghj2
Browse files

支持纯bf16

parent 58b43d4a
...@@ -75,7 +75,7 @@ dense_attn_decode_interface( ...@@ -75,7 +75,7 @@ dense_attn_decode_interface(
const int num_heads = num_heads_k; const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 64), 1); int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 16), 1);
KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
namespace Config { namespace Config {
static constexpr int BLOCK_SIZE_M = 64; static constexpr int BLOCK_SIZE_M = 16;
static constexpr int PAGE_BLOCK_SIZE = 64; static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_K = 576;
......
...@@ -5,22 +5,598 @@ ...@@ -5,22 +5,598 @@
#include "params.h" #include "params.h"
#include "config.h" #include "config.h"
#include "traits.h" #include "traits.h"
#include "softmax.h"
using namespace cute; using namespace cute;
namespace sm90 { namespace sm90 {
// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking template<typename T>
// The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2) __device__ void
// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams params,
static constexpr float MAX_INIT_VAL_SM = -1e30f; const int bidb, const int bidh, const int m_block,
static constexpr float MAX_INIT_VAL = -1e33f; const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit)
{
extern __shared__ char shared_memory[];
using SharedMemoryPlan = typename T::SharedMemoryPlan;
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(shared_memory);
const int tidx = threadIdx.x;
constexpr int kBlockM = T::BLOCK_SIZE_M;
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
constexpr int kHeadDim = T::HEAD_DIM_K;
constexpr int kHeadDimV = T::HEAD_DIM_V;
using Element = T::InputT;
using index_t = int64_t;
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
const index_t row_offset_k = (bidh) * params.k_head_stride;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.k_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), typename T::SmemLayoutQ{});
Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutV{});
Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutK{});
Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), typename T::SmemLayoutP{});
Tensor sVt = make_tensor(sV.data(), typename T::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle{});
Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), typename T::SmemLayoutRow{});
Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), typename T::SmemLayoutRow{});
typename T::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
typename T::TiledMma_O tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
#if 1
typename T::GmemTiledCopyQ gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQgQ)));
if (tidx < 128)
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, false>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.q_seq_per_hk - m_block * kBlockM);
__syncthreads();
auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
#else
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ,
params.q_seq_per_hk - m_block * kBlockM);
__syncthreads();
#endif
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(tSrQ(0)), float(tSrQ(1)));
// }
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSgK = smem_thr_copy_K.partition_S(gK);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
Tensor tSrK = thr_mma.partition_fragment_B(sK);
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16, Element>{}, tiled_mma_o);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle);
constexpr int n_masking_steps = !T::Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int n_block = n_block_max - 1;
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
// Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// Tensor tKcK_smem = smem_thr_copy_K.partition_S(cK);
Tensor tKpK_smem = make_tensor<bool>(make_shape(size<2>(tSgK)));
Tensor tSrK_smem = thr_mma.partition_fragment_B(gK);
constexpr static int k0_lds_loops = 15;
constexpr static int k0_loops = size<2>(tSrK_smem);
constexpr static int k1_loops = size<2>(tOrVt);
constexpr static int STAGE = 15;
for (int masking_step = 0; masking_step < n_masking_steps && n_block >= n_block_min; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s);
// asm volatile("s_barrier\n\t");
// 这个也做过循环2类似的修改,但是性能不如现在的好,所以保持不变
int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k);
#pragma unroll
for (int i = 0; i < STAGE; i++) {
flash::lds_direct_copy<false, true>(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN);
}
constexpr static int BUFFER_SIZE = 3;
uint128_t buffer[BUFFER_SIZE];
flash::buffer_load_copy<false, true, true, true>(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash::buffer_load_copy<false, true, true, true>(gK, buffer[1], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash::buffer_load_copy<false, true, true, true>(gK, buffer[2], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
// if constexpr (STAGE == 15)
{
int k_idx = 0;
// k_idx++;
asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
}
asm volatile("s_waitcnt vmcnt(2) \n\t \n\t");
flash::buffer_to_tensor(buffer[0], tSrK_smem, 15);
cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s);
asm volatile("s_waitcnt vmcnt(1) \n\t \n\t");
flash::buffer_to_tensor(buffer[1], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
flash::buffer_to_tensor(buffer[2], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
// asm volatile("s_barrier\n\t");
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", acc_s(0), acc_s(1));
// }
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!T::Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
// We have key_padding_mask so we'll need to Check_inf
if constexpr (n_masking_steps == 1) {
softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/T::Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
}
else {
const bool is_first_masking_step = masking_step == 0;
is_first_masking_step
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/T::Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/T::Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
}
Tensor rP = flash::convert_type<Element>(acc_s);
// Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP);
Tensor tOrP = flash::convert_layout_acc_Aregs_dense(tiled_mma, tiled_mma_o, rP, sP);
__syncthreads();
flash::lds_direct_copy<false, true>(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN);
// asm_ds_write(buffer[0], tVsV, 15);
// asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
gK.data() = gK.data() + (-offset_k);
#pragma unroll
for (int i = 0; i < k1_loops; i++) {
cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i));
cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o);
}
// asm volatile("s_barrier\n\t");
}
for (; n_block >= n_block_min; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s);
// asm volatile("s_barrier\n\t");
int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k);
#pragma unroll
for (int i = 0; i < 16; i++) {
flash::lds_direct_copy<true, true>(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN);
}
constexpr static int BUFFER_SIZE = 2;
uint128_t buffer[BUFFER_SIZE];
// buffer_load_copy<true, true, true, true>(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash::buffer_load_copy<true, true, true, true>(gK, buffer[0], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash::buffer_load_copy<true, true, true, true>(gK, buffer[1], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
// if constexpr (STAGE == 15)
{
int k_idx = 0;
// k_idx++;
asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t");
k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
__builtin_amdgcn_sched_barrier(0);
flash::__ds_read_m32x16_row_col<3, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<3, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<3, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<3, 3>(tOsVt, tOrVt_copy_view);
__builtin_amdgcn_sched_barrier(0);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t");
}
asm volatile("s_waitcnt vmcnt(1) \n\t \n\t");
flash::buffer_to_tensor(buffer[0], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
flash::buffer_to_tensor(buffer[1], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
gK.data() = gK.data() + (-offset_k);
// We have key_padding_mask so we'll need to Check_inf
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(acc_s);
// Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP);
Tensor tOrP = flash::convert_layout_acc_Aregs_dense(tiled_mma, tiled_mma_o, rP, sP);
flash::__ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<1, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o);
cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o);
flash::__ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o);
cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o);
// asm volatile("s_barrier\n\t");
}
using ElementAccum = float;
if (NoSplit)
{
using ElementO = Element;
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM;
constexpr bool Split = false;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// }
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
Tensor rO = flash::convert_type<ElementO>(acc_o);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16;
if (row < params.q_seq_per_hk - m_block * kBlockM) {
for (int n = 0; n < size<2>(acc_o); n++) {
col = (tidx % 64 / 16) + warpid * 32 + n * 128;
for (int ei = 0; ei < 8; ei ++) {
gOaccum(row, col) = rO(ei, m, n);
col += 4;
}
}
}
}
}
}
else
{
using ElementO = float;
int split_idx = params.num_splits_ptr[bidb] + n_split_idx;
constexpr bool Split = true;
const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16;
if (row < params.q_seq_per_hk - m_block * kBlockM) {
for (int n = 0; n < size<2>(acc_o); n++) {
col = (tidx % 64 / 16) + warpid * 32 + n * 128;
for (int ei = 0; ei < 8; ei ++) {
gOaccum(row, col) = acc_o(ei, m, n);
col += 4;
}
}
}
}
}
}
}
template<typename T> template<typename T>
__global__ void __launch_bounds__(T::NUM_THREADS, 1) __global__ void __launch_bounds__(T::NUM_THREADS, 1)
flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) { flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int partition_idx = blockIdx.z;
DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
if (sched_meta.begin_req_idx >= params.b) return;
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
if (batch_idx > sched_meta.begin_req_idx) {
__syncthreads();
}
compute_attn_1rowblock_splitkv_mla_gfx936<T>(params, batch_idx, bidh, m_block, n_split_idx,
seqlen_k, start_block_idx, end_block_idx, is_no_split
);
}
} }
...@@ -29,15 +605,19 @@ void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) { ...@@ -29,15 +605,19 @@ void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
FLASH_ASSERT(params.d == Config::HEAD_DIM_K); FLASH_ASSERT(params.d == Config::HEAD_DIM_K);
FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V); FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);
using T = Traits<InputT>;
auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);
auto mla_kernel = &flash_fwd_splitkv_mla_kernel<T>;
constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan); constexpr size_t smem_size = 65536;
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
using T = Traits<InputT, Is_causal>;
const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
auto mla_kernel = &flash_fwd_splitkv_mla_kernel<T>;
mla_kernel<<<dim3(num_m_block, params.h_k, params.num_sm_parts), T::NUM_THREADS, smem_size, params.stream>>>(params);
});
// cudaLaunchConfig_t mla_kernel_config = { // cudaLaunchConfig_t mla_kernel_config = {
// dim3(num_m_block, params.h_k, params.num_sm_parts), // dim3(num_m_block, params.h_k, params.num_sm_parts),
// dim3(T::NUM_THREADS, 1, 1), // dim3(T::NUM_THREADS, 1, 1),
......
...@@ -7,13 +7,12 @@ ...@@ -7,13 +7,12 @@
#include "config.h" #include "config.h"
using TMABarrier = cutlass::arch::ClusterTransactionBarrier;
using namespace cute; using namespace cute;
template<typename InputT_> template<typename InputT_, bool Is_causal_>
struct Traits { struct Traits {
using InputT = InputT_; using InputT = InputT_;
static constexpr bool Is_causal = Is_causal_;
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
...@@ -23,63 +22,105 @@ struct Traits { ...@@ -23,63 +22,105 @@ struct Traits {
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>); static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);
using TiledMMA_QK_sQ = decltype(make_tiled_mma( static constexpr int kBlockM = BLOCK_SIZE_M;
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(), static constexpr int kBlockN = PAGE_BLOCK_SIZE;
Layout<Shape<_1, _1, _1>>{} static constexpr int kHeadDim = HEAD_DIM_K;
)); static constexpr int kHeadDimV = HEAD_DIM_V;
static constexpr int kNWarps = 4;
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( using Element = InputT;
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(), using elem_type = Element;
Layout<Shape<_1, _1, _1>>{} using ElementAccum = float;
));
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
using SmemLayoutAtomK = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<16 * 32>>{}));
using SmemLayoutK_place_holder = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<15 * 32>>{}));
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomQ = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape( using SmemLayoutQ = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{}, SmemLayoutAtomQ{},
Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_K>>{} Shape<Int<kBlockM>, Int<kHeadDim>>{}));
)); using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
// #if defined(__gfx936__) || defined(__gfx938__)
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
// #elif defined(__gfx928__)
// using MMA_Atom_Arch = std::conditional_t<
// std::is_same_v<elem_type, cutlass::half_t>,
// MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
// MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
// >;
// using TiledMma = TiledMMA<
// MMA_Atom_Arch,
// Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
// ValLayoutMNK>;
// #endif
using MMA_Atom_Arch_16x32 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x32,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using GmemLayoutAtomQ = Layout<Shape <_32, _8>,
Stride< _8, _1>>;
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{}));
using SmemLayoutK = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>{}
));
using SmemLayoutV = decltype(composition(
SmemLayoutK{},
make_layout(Shape<Int<HEAD_DIM_V>, Int<PAGE_BLOCK_SIZE>>{}, GenRowMajor{})
)); // A transposed version of SmemLayoutK
using SmemLayoutP0 = decltype(tile_to_shape( struct SharedMemoryPlan {
GMMA::Layout_K_SW128_Atom<InputT>{}, union {
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{} struct {
)); cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // Double buffer
using rP0Layout = decltype(layout(partition_fragment_C( };
TiledMMA_QK_sQ{}, struct {
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{} cute::array_aligned<Element, cute::cosize_v<SmemLayoutK_place_holder>> smem_temp; // Double buffer
))); cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
};
struct SharedMemoryPlan { };
cute::array_aligned<InputT, cosize_v<SmemLayoutQ>> smem_sQ;
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK0;
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK1;
cute::array_aligned<InputT, cosize_v<SmemLayoutP0>> smem_sP0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sM;
cute::array_aligned<float, 2*BLOCK_SIZE_M> sL_reduction_wksp;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale1;
TMABarrier barriers_K0[HEAD_DIM_K/64];
TMABarrier barriers_K1[HEAD_DIM_K/64];
TMABarrier barrier_Q;
}; };
}; };
......
...@@ -88,6 +88,18 @@ struct RingBufferState { ...@@ -88,6 +88,18 @@ struct RingBufferState {
} }
}; };
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
namespace flash { namespace flash {
using namespace cute; using namespace cute;
...@@ -559,5 +571,170 @@ lds_direct_copy_for_prefill_sparse_mla( ...@@ -559,5 +571,170 @@ lds_direct_copy_for_prefill_sparse_mla(
} }
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy(
Tensor<SrcEngine, SrcLayout> const& src,
uint128_t & dst,
int k_idx_, const int row_stride,
int offset_k,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 8;
if constexpr (mma_layout)
{
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint128_t*>(&res);
}
}
else
{
uint32x4_t global_addr = {0};
*(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = tidx / 4;
int col = lane % 4;
int row_offset = row;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint128_t*>(&res);
}
}
}
template<
class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
buffer_to_tensor(const uint128_t & src, Tensor<SrcEngine, SrcLayout> & dst, int k_idx)
{
uint128_t* d = reinterpret_cast<uint128_t*>(&dst(0, 0, k_idx));
d[0] = src;
}
template <class TiledMma, class TiledMma_O,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ auto convert_layout_acc_Aregs_dense(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor<Engine0, Layout0> const& tOrP,
Tensor<Engine1, Layout1> const& sAcc)
{
using Value_type = typename Engine0::value_type;
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
// __fp16 *smem_ptr =
// sAcc((tid % 16 ) * 4 + (tid / 16) + warp_id * 16 * 16) = tOrP(0, 0, 0);
// sAcc((tid % 16 ) * 4 + (tid / 16) + 16 * 4 + warp_id * 16 * 16) = tOrP(1, 0, 0);
// sAcc((tid % 16 ) * 4 + (tid / 16) + 2 * 16 * 4 + warp_id * 16 * 16) = tOrP(2, 0, 0);
// sAcc((tid % 16 ) * 4 + (tid / 16) + 3 * 16 * 4 + warp_id * 16 * 16) = tOrP(3, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16) + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(0, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16) + 1 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(1, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16) + 2 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(2, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16) + 3 * 16 * 8 + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32) = tOrP(3, 0, 0);
__syncthreads();
using SmemLayoutAtomP = Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<16>, Int<64>>{}));
Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{});
auto thr_mma = tiled_mma_o.get_thread_slice(tid);
Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp);
tSrACC(0, 0, 0) = sAcc(tid * 8 + 0);
tSrACC(1, 0, 0) = sAcc(tid * 8 + 1);
tSrACC(2, 0, 0) = sAcc(tid * 8 + 2);
tSrACC(3, 0, 0) = sAcc(tid * 8 + 3);
tSrACC(0, 0, 1) = sAcc(tid * 8 + 0 + 4);
tSrACC(1, 0, 1) = sAcc(tid * 8 + 1 + 4);
tSrACC(2, 0, 1) = sAcc(tid * 8 + 2 + 4);
tSrACC(3, 0, 1) = sAcc(tid * 8 + 3 + 4);
tSrACC(0, 0, 2) = sAcc(tid * 8 + 0 + 16*32);
tSrACC(1, 0, 2) = sAcc(tid * 8 + 1 + 16*32);
tSrACC(2, 0, 2) = sAcc(tid * 8 + 2 + 16*32);
tSrACC(3, 0, 2) = sAcc(tid * 8 + 3 + 16*32);
tSrACC(0, 0, 3) = sAcc(tid * 8 + 0 + 4 + 16*32);
tSrACC(1, 0, 3) = sAcc(tid * 8 + 1 + 4 + 16*32);
tSrACC(2, 0, 3) = sAcc(tid * 8 + 2 + 4 + 16*32);
tSrACC(3, 0, 3) = sAcc(tid * 8 + 3 + 4 + 16*32);
return tSrACC;
}
} }
\ No newline at end of file
...@@ -223,9 +223,10 @@ def main(torch_dtype): ...@@ -223,9 +223,10 @@ def main(torch_dtype):
] ]
performance_cases = [ performance_cases = [
TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, test_performance=True) TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, h_q = h_q, test_performance=True)
for is_causal in [False, True] for is_causal in [False, True]
for s_q in [1, 2] for s_q in [1, 2]
for h_q in [16, 128]
for s_k in [4096, 8192, 16384, 32768] for s_k in [4096, 8192, 16384, 32768]
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment