Commit aec17474 authored by zhanghj2
Browse files

Feature/kimi nhead64 dense

parent a45f646b
......@@ -92,7 +92,7 @@ dense_attn_decode_interface(
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
if (!tile_scheduler_metadata.has_value()) {
if (!tile_scheduler_metadata.has_value() && ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1)) {
tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
......@@ -125,20 +125,6 @@ dense_attn_decode_interface(
if (const char* val = std::getenv("FLASH_MLA_PRINT_PARAM")) {
print_param = (std::string(val) == "1");
}
if (print_param) {
fprintf(stderr, "[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d \n",
arch.archName.c_str(),
batch_size,
seqlen_q_ori,
num_heads_q,
head_size_k,
max_num_blocks_per_seq,
num_blocks,
page_block_size,
num_heads_k
);
}
// Set the sizes
DenseAttnDecodeParams params;
params.b = batch_size;
......@@ -174,10 +160,10 @@ dense_attn_decode_interface(
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
if ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1) {
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
......@@ -186,7 +172,65 @@ dense_attn_decode_interface(
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
params.use_split_kv = false;
} else {
bool use_split_kv = true;
int num_m_blocks = (params.q_seq_per_hk + 64 - 1) / 64;
int num_sms = arch.num_sms;
int num_splits = num_sms * 3 / (num_m_blocks * params.b);
if (max_num_blocks_per_seq >= 32768 / 64) {
num_splits = 32;
} else if (max_num_blocks_per_seq >= 16384 / 64) {
num_splits = 32;
} else if (max_num_blocks_per_seq >= 8192 / 64) {
num_splits = 16;
} else if (max_num_blocks_per_seq >= 4096 / 64) {
num_splits = 8;
} else if (max_num_blocks_per_seq >= 2048 / 64) {
num_splits = 4;
} else {
num_splits = 1;
}
if (params.b >= 128) {
num_splits = 1;
}
if (num_splits <= 1) {
use_split_kv = false;
}
else {
num_splits = std::min(num_splits, 240);
params.partition_block_nums = max_num_blocks_per_seq / num_splits;
}
if (params.partition_block_nums <= 4) {
use_split_kv = false;
}
params.use_split_kv = use_split_kv;
params.total_num_splits = params.b * num_splits;
at::Tensor lse_accum = torch::empty({params.total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({params.total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(out_accum);
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
}
if (print_param) {
fprintf(stderr, "[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d use_split_kv = %d num_splits %d params.partition_block_nums = %d \n",
arch.archName.c_str(),
batch_size,
seqlen_q_ori,
num_heads_q,
head_size_k,
max_num_blocks_per_seq,
num_blocks,
page_block_size,
num_heads_k,
params.use_split_kv,
params.total_num_splits / params.b,
params.partition_block_nums
);
}
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) {
......@@ -220,18 +264,24 @@ dense_attn_decode_interface(
params.num_sm_parts,
nullptr,
at::cuda::getCurrentCUDAStream().stream()
at::cuda::getCurrentCUDAStream().stream(),
params.use_split_kv,
params.total_num_splits / params.b,
params.seqlens_k_ptr,
params.partition_block_nums
};
if ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1 || params.use_split_kv) {
if (q_dtype == torch::kBFloat16) {
gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
#ifndef FLASH_MLA_DISABLE_FP16
gfx9::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);
#endif
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
}
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
......
......@@ -16,7 +16,7 @@ using namespace cute;
namespace gfx9::decode {
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS, bool USE_SPLIT_KV=false>
__global__ void __launch_bounds__(NUM_THREADS, 1)
flash_fwd_mla_combine_kernel(const CombineParams params) {
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
......@@ -33,12 +33,36 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
return;
}
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
const int my_num_splits = end_split_idx - start_split_idx;
int start_split_idx;
int end_split_idx;
int my_num_splits;
if constexpr (USE_SPLIT_KV)
{
start_split_idx = batch_idx * params.num_splits;
end_split_idx = (batch_idx + 1) * params.num_splits;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
end_split_idx = std::min(cute::ceil_div(cute::ceil_div(seqlen_k, 64), params.partition_block_nums), params.num_splits) + start_split_idx;
// if (lane_idx == 0 && batch_idx == 61)
// {
// printf(" batch_idx = %d start_split_idx = %d end_split_idx = %d seqlen_k = %d \n",batch_idx, start_split_idx, end_split_idx, seqlen_k);
// }
my_num_splits = end_split_idx - start_split_idx;
if (my_num_splits == 1) {
return;
}
}
else
{
start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
my_num_splits = end_split_idx - start_split_idx;
if (my_num_splits == 1) {
return;
}
}
// FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
......@@ -245,6 +269,9 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 240) { \
constexpr static int NAME = 240; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
......@@ -255,29 +282,33 @@ template<typename ElementT>
/// Host-side launcher for the split-combine kernel.
///
/// Picks a compile-time split bucket through MLA_NUM_SPLITS_SWITCH and launches
/// flash_fwd_mla_combine_kernel on `params.stream`:
///   * `params.use_split_kv == false`: the bucket is chosen from
///     `params.num_sm_parts` and the kernel uses its default USE_SPLIT_KV value.
///   * `params.use_split_kv == true`: the bucket is chosen from
///     `params.num_splits` and the kernel is instantiated with USE_SPLIT_KV = true.
///
/// Both paths use the same launch geometry: grid (b, s_q, ceil_div(h_q, 4)),
/// 4 * 64 threads per block and 4 * (NUM_SPLITS + 1) floats of dynamic shared
/// memory. The launch status is checked once after either path.
///
/// @param params Combine parameters; `params.d_v` must equal 512.
void run_flash_mla_combine_kernel(CombineParams &params) {
    // Only this value head dimension is supported by Flash MLA.
    static constexpr int HEAD_DIM_V = 512;
    FLASH_ASSERT(params.d_v == HEAD_DIM_V);

    if (!params.use_split_kv) {
        // Scheduler-metadata path: the bucket is selected from the number of SM parts.
        MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
            constexpr int BLOCK_SIZE_M = 4;
            constexpr int NUM_THREADS = BLOCK_SIZE_M * 64;
            constexpr size_t smem_size = BLOCK_SIZE_M * (NUM_SPLITS + 1) * sizeof(float);
            const dim3 launch_grid(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M));
            auto kernel_fn = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
            // Note: a cudaLaunchKernelEx / programmatic-stream-serialization (PDL)
            // launch was sketched here in a disabled comment; the plain
            // triple-chevron launch below is what actually runs.
            kernel_fn<<<launch_grid, NUM_THREADS, smem_size, params.stream>>>(params);
        });
    } else {
        // Fixed-partition split-KV path: the bucket is selected from params.num_splits.
        MLA_NUM_SPLITS_SWITCH(params.num_splits, NUM_SPLITS, [&] {
            constexpr int BLOCK_SIZE_M = 4;
            constexpr int NUM_THREADS = BLOCK_SIZE_M * 64;
            constexpr size_t smem_size = BLOCK_SIZE_M * (NUM_SPLITS + 1) * sizeof(float);
            const dim3 launch_grid(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M));
            auto kernel_fn = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS, true>;
            kernel_fn<<<launch_grid, NUM_THREADS, smem_size, params.stream>>>(params);
        });
    }

    CHECK_CUDA_KERNEL_LAUNCH();
}
......
......@@ -10,6 +10,613 @@ using namespace cute;
namespace gfx93 {
// template<typename T, bool use_split_kv=false>
// __device__ void
// compute_attn_1rowblock_splitkv_mla_block_m_64_gfx936(const DenseAttnDecodeParams& params,
// const int bidb, const int bidh, const int m_block,
// const int n_split_idx, const int seqlen_k,
// const int n_block_min, const int n_block_max, const bool NoSplit)
// {
// extern __shared__ char shared_memory[];
// const int tidx = threadIdx.x;
// constexpr int kBlockM = T::BLOCK_SIZE_M;
// constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
// constexpr int kHeadDim = T::HEAD_DIM_K;
// constexpr int kHeadDimV = T::HEAD_DIM_V;
// using Element = T::InputT;
// using index_t = int64_t;
// const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64);
// const int lane_idx = tidx % 64;
// Element* q_lds = (Element*)&(shared_memory);
// Element* k_lds = q_lds;
// Element* v_lds = q_lds;
// const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
// Shape<Int<kBlockM>, Int<kHeadDim>>{},
// make_stride(params.q_row_stride, _1{}));
// // const index_t row_offset_k = (0) * params.k_head_stride;
// // Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
// // Shape<Int<kBlockN>, Int<kHeadDim>>{},
// // make_stride(params.k_row_stride, _1{}));
// typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
// typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
// typedef __bf16 __fp16x2_t __attribute__((ext_vector_type(2)));
// union Bf16_storage {
// __fp16x8_t data_128;
// __fp16x4_t data_64[2];
// __fp16x2_t data_32[4];
// uint16_t data_array[8];
// };
// union Bf16_storage_x4 {
// __fp16x4_t data_64;
// __fp16x2_t data_32[2];
// uint16_t data[4];
// };
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr_q;
// *(uint64_t*)&glob_ptr_q = reinterpret_cast<uint64_t>(gQ.data().get());
// glob_ptr_q.latter |= ((params.q_row_stride * 2) << 16);
// glob_ptr_q.latter |= 0x40000000;
// uint32x4_t global_addr_q = {0};
// global_addr_q[0] = (glob_ptr_q.former);
// global_addr_q[1] = (glob_ptr_q.latter);
// global_addr_q[2] = params.q_seq_per_hk - m_block * kBlockM;
// global_addr_q[3] = 0x00020000;
// int virtual_row_ = lane_idx / 8;//0
// int virtual_col_ = lane_idx % 8;//0
// int swizzle_col_ = virtual_row_ ^ virtual_col_;
// int row_ = lane_idx / 4;//0
// // 8->9 9->8
// // row_ = (row_ >= 8 ) ^ row_;
// int col_ = swizzle_col_ % 4;
// auto calc_row_and_col_k = [&]() -> std::tuple<int, int> {
// constexpr int elements_per_thread = 8;
// #if defined(__gfx938__)
// int row_offset = row_ + warp_idx * 16;
// int col_offset = col_ * 8;
// #else
// int row_offset = row_ * 4 + warp_idx;
// int col_offset = col_ * 8;
// #endif
// return {row_offset, col_offset};
// };
// auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx, int block_idx, index_t offset_k) {
// constexpr int element_size = 2;
// PtrWrapper glob_ptr_k;
// *(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(params.k_ptr) + offset_k * 2;
// glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16);
// glob_ptr_k.latter |= 0x40000000;
// uint32x4_t global_addr_k = {0};
// global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former);
// global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter);
// global_addr_k[2] = seqlen_k - block_idx * kBlockN;
// global_addr_k[3] = 0x00020000;
// constexpr int elements_per_thread = 8;
// int col_offset = col;
// int offset_v = col_offset * 2;
// int ldsAddrPerWave = reinterpret_cast<size_t>(k_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 4) * 64 * 32 * 2;
// typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// uint32x2_t index_offset = {0};
// index_offset[0] = row_offset;
// index_offset[1] = offset_v;
// const int offset_s = k_idx * 32 * 2;
// __builtin_amdgcn_sched_barrier(0);
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "s_nop 0 \n\t"
// "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
// "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s)
// :);
// __builtin_amdgcn_sched_barrier(0);
// };
// auto k_lds_read_offset = [&] () -> int {
// int row = lane_idx % 16;
// int col = lane_idx / 16;
// col = (row / 2) ^ col;
// col = col % 4;
// // row = (row >= 8) ^ row;
// const auto lds_offset = row * 32 + col * 8;
// // #endif
// return lds_offset;
// };
// auto calc_row_and_col_v = [&](int i) -> int {
// int row = lane_idx / 4;
// // int col = lane_idx % 4;
// int row_offset = row + i * 16;
// // int col_offset = col * 8 + warp_idx * 32;
// return row_offset;
// };
// const int v_lds_read_ptr = reinterpret_cast<size_t>(v_lds + lane_idx * 8);
// Element* k_lds_read_ptr = (k_lds + k_lds_read_offset());
// int col_offset_v = (lane_idx % 4) * 8 + warp_idx * 32;
// auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx, int block_idx, index_t offset_k) {
// constexpr int element_size = 2;
// PtrWrapper glob_ptr_k;
// *(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(params.k_ptr) + offset_k * 2;
// glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16);
// glob_ptr_k.latter |= 0x40000000;
// uint32x4_t global_addr_k = {0};
// global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former);
// global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter);
// global_addr_k[2] = seqlen_k - block_idx * kBlockN;
// global_addr_k[3] = 0x00020000;
// constexpr int elements_per_thread = 8;
// int col_offset = col;
// // int v_idx = row_offset;
// int offset_v = col_offset * 2;
// int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2;
// typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// uint32x2_t index_offset = {0};
// index_offset[0] = row_offset;
// index_offset[1] = offset_v;
// const int offset_s = n_idx * 128 * 2;
// __builtin_amdgcn_sched_barrier(0);
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "s_nop 0 \n\t"
// "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
// "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s)
// :);
// __builtin_amdgcn_sched_barrier(0);
// };
// Bf16_storage q_reg[18];
// for (int i = 0; i < 18; i++)
// {
// constexpr int elements_per_thread = 8;
// int row = lane_idx % 16;
// int col = lane_idx / 16;
// int row_offset = row + warp_idx * 16;
// int col_offset = col * 8;
// int offset_v = col_offset * 2 + i * 32 * 2;
// q_reg[i].data_128 = __builtin_amdgcn_buffer_load_dwordx4(global_addr_q, row_offset, offset_v, false, false);
// }
// __syncthreads();
// v4f acco_f32[32];
// for (int i = 0; i < 32; i++)
// {
// acco_f32[i].x = 0.0f;
// acco_f32[i].y = 0.0f;
// acco_f32[i].z = 0.0f;
// acco_f32[i].w = 0.0f;
// }
// const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
// auto float2bf16 = [] (float s) -> uint16_t {
// uint32_t x32 = reinterpret_cast<uint32_t const &>(s);
// #ifndef FLASH_MLA_BF16_TYPE
// #define FLASH_MLA_BF16_TYPE 0
// #endif
// #if FLASH_MLA_BF16_TYPE == 1
// x32 += 0x8000u;
// #endif
// return uint16_t(x32 >> 16);
// };
// // int block_idx = 0;
// // int cur_block_table = block_table[block_idx];
// // index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride;
// // auto [row_offset, col] = calc_row_and_col_k(block_idx);
// // buffer_load_lds_k(row_offset, col, 0, offset_k);
// // __syncthreads();
// {
// // if (thread0())
// // {
// // int k = 0;
// // for (int i = 0; i < 64; i++)
// // {
// // for (int j = 0; j < 32; j++)
// // {
// // printf(" %.3f ", float(k_lds[k]));
// // k++;
// // }
// // printf("\n");
// // }
// // }
// // if (block0() && threadIdx.x < 64)
// // {
// // cutlass::bfloat16_t q[8];
// // for (int i = 0; i < 8; i++)
// // {
// // q[i].storage = v_reg[0].data_array[i];
// // // q[i].storage = q_reg[0].data_array[i];
// // }
// // printf("tidx %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %d\n ", threadIdx.x,
// // float(q[0]),
// // float(q[1]),
// // float(q[2]),
// // float(q[3]),
// // float(q[4]),
// // float(q[5]),
// // float(q[6]),
// // float(q[7]),
// // v_lds_read_ptr
// // );
// // }
// }
// struct IsMaskBlock {};
// struct IsFirstMaskBlock {};
// struct IsNoMaskBlock {};
// flash::Softmax<1> softmax;
// auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
// static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsMaskBlock>;
// static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
// static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
// int cur_block_table = block_table[block_idx];
// v4f accs_f32[4];
// for (int i = 0; i < 4; i++)
// {
// accs_f32[i].x = 0.0f;
// accs_f32[i].y = 0.0f;
// accs_f32[i].z = 0.0f;
// accs_f32[i].w = 0.0f;
// }
// index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride;
// auto [row_offset, col] = calc_row_and_col_k();
// #define LOAD_K_AND_QK_GEMM(k) \
// { \
// constexpr int k_val = (k); \
// buffer_load_lds_k(row_offset, col, k_val - 3, block_idx, offset_k); \
// flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// }
// {
// constexpr int k_val = (17);
// buffer_load_lds_k(row_offset, col, k_val, block_idx, offset_k);
// buffer_load_lds_k(row_offset, col, k_val - 1, block_idx, offset_k);
// buffer_load_lds_k(row_offset, col, k_val - 2, block_idx, offset_k);
// buffer_load_lds_k(row_offset, col, k_val - 3, block_idx, offset_k);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// LOAD_K_AND_QK_GEMM(16);
// LOAD_K_AND_QK_GEMM(15);
// LOAD_K_AND_QK_GEMM(14);
// LOAD_K_AND_QK_GEMM(13);
// LOAD_K_AND_QK_GEMM(12);
// LOAD_K_AND_QK_GEMM(11);
// LOAD_K_AND_QK_GEMM(10);
// LOAD_K_AND_QK_GEMM(9);
// LOAD_K_AND_QK_GEMM(8);
// LOAD_K_AND_QK_GEMM(7);
// LOAD_K_AND_QK_GEMM(6);
// LOAD_K_AND_QK_GEMM(5);
// LOAD_K_AND_QK_GEMM(4);
// LOAD_K_AND_QK_GEMM(3);
// flash::qk_gemm<Element, k_val - 15>(q_reg[k_val - 15].data_128, k_lds_read_ptr, accs_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::qk_gemm<Element, k_val - 16>(q_reg[k_val - 16].data_128, k_lds_read_ptr, accs_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::qk_gemm<Element, k_val - 17>(q_reg[k_val - 17].data_128, k_lds_read_ptr, accs_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// }
// // if (block0() && tidx < 64)
// // {
// // printf(" %.3f %.3f \n", accs_f32[0][0], accs_f32[0][1]);
// // }
// if constexpr (!IS_NO_MASK_BLOCK) {
// for (int i = 0; i < 16; ++i) {
// int idx = i;
// if constexpr (!T::Is_causal) {
// if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 >= int(seqlen_k - block_idx * kBlockN))
// {
// #if defined(__gfx938__)
// accs_f32[i/4][i%4] = -INFINITY;
// #else
// accs_f32[i%4][i/4] = -INFINITY;
// #endif
// }
// } else {
// // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// int row = (lane_idx % 16) + warp_idx * 16;;
// int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk;
// if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 > col_limit_right) {
// #if defined(__gfx938__)
// accs_f32[i/4][i%4] = -INFINITY;
// #else
// accs_f32[i%4][i/4] = -INFINITY;
// #endif
// }
// }
// }
// }
// Tensor scores = make_tensor<float>(Shape<_1, _16>{});
// for (int i = 0; i < 16; i++) {
// #if defined(__gfx938__)
// scores(0, i) = accs_f32[i/4][i%4];
// #else
// scores(0, i) = accs_f32[i%4][i/4];
// #endif
// }
// softmax.template softmax_rescale_o_prefill_4x1</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*//*Is_local=*/T::Is_causal>(scores, acco_f32, params.scale_softmax_log2);
// Bf16_storage_x4 p[4];
// for (int i = 0; i < 4; i++)
// {
// #if defined(__gfx938__)
// p[i].data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4), 0, scores(0, i * 4 + 1), 0);
// p[i].data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4 + 2), 0, scores(0, i * 4 + 3), 0);
// #else
// p[i].data[0] = float2bf16(scores(0, i * 4));
// p[i].data[1] = float2bf16(scores(0, i * 4 + 1));
// p[i].data[2] = float2bf16(scores(0, i * 4 + 2));
// p[i].data[3] = float2bf16(scores(0, i * 4 + 3));
// #endif
// }
// int row_offset_v[4];
// for (int i = 0; i < 4; i++)
// {
// row_offset_v[i] = calc_row_and_col_v(i);
// }
// // __syncthreads();
// // #if 1
// {
// constexpr int k_val = (0);
// buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0, block_idx, offset_k);
// buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0, block_idx, offset_k);
// buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 0, block_idx, offset_k);
// buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 0, block_idx, offset_k);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1, block_idx, offset_k);
// flash::pv_gemm<k_val + 1, 0>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 1, 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 1, 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 1, 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1, block_idx, offset_k);
// flash::pv_gemm<k_val + 2, 0>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 2, 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 2, 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 2, 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 1, block_idx, offset_k);
// flash::pv_gemm<k_val + 3, 0>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 3, 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 3, 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 3, 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 1, block_idx, offset_k);
// }
// #define LOAD_V_AND_PV_GEMM(n) \
// { \
// constexpr int k_val = (0); \
// constexpr int n_val = (n); \
// flash::pv_gemm<k_val, n_val * 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val, n_val * 4 + 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val, n_val * 4 + 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val, n_val * 4 + 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1, block_idx, offset_k); \
// flash::pv_gemm<k_val + 1, n_val * 4>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 1, n_val * 4 + 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 1, n_val * 4 + 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 1, n_val * 4 + 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1, block_idx, offset_k); \
// flash::pv_gemm<k_val + 2, n_val * 4>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 2, n_val * 4 + 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 2, n_val * 4 + 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 2, n_val * 4 + 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1, block_idx, offset_k); \
// flash::pv_gemm<k_val + 3, n_val * 4>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 3, n_val * 4 + 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 3, n_val * 4 + 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 3, n_val * 4 + 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1, block_idx, offset_k); \
// }
// LOAD_V_AND_PV_GEMM(1);
// LOAD_V_AND_PV_GEMM(2);
// {
// constexpr int n_val = (3);
// flash::pv_gemm<0, 12>(p[0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<0, 13>(p[0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<0, 14>(p[0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<0, 15>(p[0].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::pv_gemm<1, 12>(p[1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<1, 13>(p[1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<1, 14>(p[1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<1, 15>(p[1].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::pv_gemm<2, 12>(p[2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<2, 13>(p[2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<2, 14>(p[2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<2, 15>(p[2].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::pv_gemm<3, 12>(p[3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<3, 13>(p[3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<3, 14>(p[3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<3, 15>(p[3].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// }
// };
// constexpr int n_masking_steps = !T::Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
// int n_block = n_block_max - 1;
// if constexpr (n_masking_steps == 1) {
// if (n_block >= n_block_min) {
// process_one_block(n_block, IsFirstMaskBlock{});
// }
// n_block--;
// } else {
// int masking_step = 1;
// if (n_block >= n_block_min) {
// process_one_block(n_block, IsFirstMaskBlock{});
// }
// n_block--;
// for (; n_block >= n_block_min && masking_step < n_masking_steps; ++masking_step, --n_block) {
// process_one_block(n_block, IsMaskBlock{});
// }
// }
// for(; n_block >= n_block_min; --n_block) {
// process_one_block(n_block, IsNoMaskBlock{});
// }
// using ElementAccum = float;
// if constexpr (true)
// {
// using ElementO = Element;
// const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
// const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM;
// constexpr bool Split = false;
// Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)),
// Shape<Int<kBlockM>, Int<kHeadDimV>>{},
// make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
// Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1</*Is_dropout=*/false, Split>(acco_f32, params.scale_softmax);
// // if (block0() && tidx < 64)
// // {
// // printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// // }
// Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)),
// Shape<Int<kBlockM>>{}, Stride<_1>{});
// {
// // using result_type = cutlass::Array<bfloat16_t, 2>;
// // int tidx = threadIdx.x;
// int row, col;
// // int warpid = tidx / 64;
// for (int mi = 0; mi < 1; mi++) {
// row = mi * kBlockM + lane_idx % 16 + warp_idx * 16;
// if (row < params.q_seq_per_hk - m_block * kBlockM) {
// for (int ni = 0; ni < 16; ++ni) {
// #if defined(__gfx938__)
// Bf16_storage res;
// col = (lane_idx / 16) * 8 + ni * 32 ;
// res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0], 0, acco_f32[ni * 2 + 1][0], 0);
// res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1], 0, acco_f32[ni * 2 + 1][1], 0);
// res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2], 0, acco_f32[ni * 2 + 1][2], 0);
// res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3], 0, acco_f32[ni * 2 + 1][3], 0);
// *(__fp16x8_t*)(&gOaccum(row, col)) = res.data_128;
// #else
// col = (lane_idx / 16) * 2 + ni * 32 ;
// using result_type = cutlass::Array<Element, 2>;
// for (int ei = 0; ei < 4; ei++)
// {
// result_type res;
// Element e0, e1;
// e0.storage = float2bf16(acco_f32[ni * 2][ei]);
// e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei]);
// res[0] = e0;
// res[1] = e1;
// // gO(row, col) = res[0];
// // gO(row, col + 1) = res[1];
// *(result_type*)(&gOaccum(row, col)) = res;
// col += 8;
// }
// #endif
// }
// // for (int n = 0; n < 1; n++) {
// // col = (tidx % 64 / 16) + warpid * 32 + n * 128;
// // for (int ei = 0; ei < 8; ei ++) {
// // gOaccum(row, col) = rO(ei, m, n);
// // col += 4;
// // }
// // }
// gLSEaccum(row) = lse(mi);
// }
// }
// }
// }
// }
template<typename T>
__device__ void
compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
......@@ -599,34 +1206,742 @@ flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
}
}
template<typename T, bool use_split_kv=false>
__global__ void __launch_bounds__(T::NUM_THREADS, 1)
flash_fwd_splitkv_mla_block_m_64_kernel(const DenseAttnDecodeParams params) {
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
int bidb;
int seqlen_k;
int n_block_min;
int n_block_max;
const int tidx = threadIdx.x;
const int lane_idx = tidx % 64;
bool is_split = use_split_kv;
if constexpr (use_split_kv)
{
int num_splits = params.total_num_splits / params.b;
bidb = blockIdx.z % params.b;
// bidb = blockIdx.z / num_splits;
seqlen_k = __ldg(params.seqlens_k_ptr + bidb);
int split_id = blockIdx.z / params.b;
n_block_min = split_id * params.partition_block_nums;
n_block_max = split_id == (num_splits - 1) ? cute::ceil_div(seqlen_k, kBlockN) :
std::min((split_id + 1) * params.partition_block_nums, cute::ceil_div(seqlen_k, kBlockN));
if (split_id == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN))
{
is_split = false;
}
// if (tidx == 0 && bidb == 61)
// {
// printf("bidb = %d split_id = %d n_block_min = %d n_block_max = %d num_splits = %d params.partition_block_nums %d is_split = %d \n", bidb, split_id, n_block_min, n_block_max, num_splits, params.partition_block_nums, is_split);
// }
if (n_block_max <= n_block_min) return;
template<typename InputT>
void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
FLASH_ASSERT(params.d == Config::HEAD_DIM_K);
FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);
}
else
{
bidb = blockIdx.z;
seqlen_k = __ldg(params.seqlens_k_ptr + bidb);
n_block_min = 0;
n_block_max = cute::ceil_div(seqlen_k, kBlockN);
}
constexpr size_t smem_size = 65536;
extern __shared__ char shared_memory[];
constexpr int kBlockM = T::BLOCK_SIZE_M;
// constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
constexpr int kHeadDim = T::HEAD_DIM_K;
constexpr int kHeadDimV = T::HEAD_DIM_V;
using Element = T::InputT;
using index_t = int64_t;
const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64);
Element* q_lds = (Element*)&(shared_memory);
Element* k_lds = q_lds;
Element* v_lds = q_lds;
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
// const index_t row_offset_k = (0) * params.k_head_stride;
// Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
// Shape<Int<kBlockN>, Int<kHeadDim>>{},
// make_stride(params.k_row_stride, _1{}));
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
typedef __bf16 __fp16x2_t __attribute__((ext_vector_type(2)));
union Bf16_storage {
__fp16x8_t data_128;
__fp16x4_t data_64[2];
__fp16x2_t data_32[4];
uint16_t data_array[8];
};
union Bf16_storage_x4 {
__fp16x4_t data_64;
__fp16x2_t data_32[2];
uint16_t data[4];
};
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr_q;
*(uint64_t*)&glob_ptr_q = reinterpret_cast<uint64_t>(gQ.data().get());
glob_ptr_q.latter |= ((params.q_row_stride * 2) << 16);
glob_ptr_q.latter |= 0x40000000;
uint32x4_t global_addr_q = {0};
global_addr_q[0] = (glob_ptr_q.former);
global_addr_q[1] = (glob_ptr_q.latter);
global_addr_q[2] = params.q_seq_per_hk - m_block * kBlockM;
global_addr_q[3] = 0x00020000;
int virtual_row_ = lane_idx / 8;//0
int virtual_col_ = lane_idx % 8;//0
int swizzle_col_ = virtual_row_ ^ virtual_col_;
int row_ = lane_idx / 4;//0
// 8->9 9->8
// row_ = (row_ >= 8 ) ^ row_;
int col_ = swizzle_col_ % 4;
auto calc_row_and_col_k = [&]() -> std::tuple<int, int> {
    // Per-lane (row, column) pair handed to buffer_load_lds_k.
    // The column is the swizzled chunk index col_ scaled by the number of
    // elements one lane moves per load; it is identical on both targets.
    constexpr int elements_per_thread = 8;
    const int k_col = col_ * elements_per_thread;
#if defined(__gfx938__)
    // gfx938: rows are grouped per wave (16 consecutive rows per wave).
    const int k_row = warp_idx * 16 + row_;
#else
    // Other targets: rows are interleaved across the waves with a stride of 4.
    const int k_row = warp_idx + row_ * 4;
#endif
    return {k_row, k_col};
};
// Issues one asynchronous global->LDS load of a K slice for the calling wave.
//   row_offset : per-lane row index, passed as the first index component
//   col        : per-lane column in elements (doubled below to form the second component)
//   k_idx      : index of the 32-column K slice; also selects the LDS slot (k_idx % 5)
//   block_idx  : KV block index, used only for the row bound (seqlen_k - block_idx * kBlockN)
//   offset_k   : element offset of the current page from params.k_ptr
// The load is not waited on here; callers pair it with s_waitcnt vmcnt(...) + s_barrier.
auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx, int block_idx, index_t offset_k) {
// NOTE(review): element_size is declared but never used; the literal 2 below presumably is the element size in bytes.
constexpr int element_size = 2;
// Build the 64-bit base address and split it into two 32-bit words via PtrWrapper.
PtrWrapper glob_ptr_k;
*(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(params.k_ptr) + offset_k * 2;
// Pack (k_row_stride * 2) into bits 16+ of the upper word and set bit 30.
// NOTE(review): presumably the byte stride and a swizzle/enable flag of the buffer descriptor -- confirm against the ISA manual.
glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16);
glob_ptr_k.latter |= 0x40000000;
// Assemble the 4-dword resource operand consumed by buffer_load below.
uint32x4_t global_addr_k = {0};
// readfirstlane makes the two address words wave-uniform so they can live in scalar registers.
global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former);
global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter);
// Number of rows still valid in this KV block (presumably the descriptor's record count, used for bounds checking -- TODO confirm).
global_addr_k[2] = seqlen_k - block_idx * kBlockN;
global_addr_k[3] = 0x00020000;
// NOTE(review): elements_per_thread is declared but never used in this lambda.
constexpr int elements_per_thread = 8;
int col_offset = col;
// Column offset scaled by 2 (presumably elements -> bytes).
int offset_v = col_offset * 2;
// LDS destination: each wave owns a 16*32*2-byte stripe inside one of 5 slots of 64*32*2 bytes, selected by k_idx % 5.
int ldsAddrPerWave = reinterpret_cast<size_t>(k_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 5) * 64 * 32 * 2;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// Two-component VGPR operand: [0] = row index (idxen), [1] = offset (offen).
uint32x2_t index_offset = {0};
index_offset[0] = row_offset;
index_offset[1] = offset_v;
// Scalar offset that selects the k_idx-th 32-column slice (32 * 2 per slice).
const int offset_s = k_idx * 32 * 2;
// Scheduling barriers keep the compiler from moving instructions across the asm block.
__builtin_amdgcn_sched_barrier(0);
// m0 receives the LDS address, then buffer_load_dwordx4 with the `lds` modifier is issued.
// NOTE(review): the asm writes m0 but declares no clobbers -- confirm this is safe with the surrounding code.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
};
auto k_lds_read_offset = [&] () -> int {
    // Element offset into the K tile in LDS from which this lane reads.
    // Rows are 32 elements wide; the 8-element chunk within the row is
    // XOR-swizzled with (row / 2) and wrapped to the 4 chunks of a row.
    const int lane_row = lane_idx % 16;
    const int swizzled_chunk = ((lane_row / 2) ^ (lane_idx / 16)) % 4;
    return lane_row * 32 + swizzled_chunk * 8;
};
auto calc_row_and_col_v = [&](int i) -> int {
    // Row index used for the i-th V load of this lane: four lanes share a
    // row, and each successive load advances by 16 rows.
    // Despite the name, only the row is returned; the column is computed
    // separately by the caller.
    return lane_idx / 4 + i * 16;
};
const int v_lds_read_ptr = reinterpret_cast<size_t>(v_lds + lane_idx * 8);
Element* k_lds_read_ptr = (k_lds + k_lds_read_offset());
int col_offset_v = (lane_idx % 4) * 8 + warp_idx * 32;
// Issues one asynchronous global->LDS load of a V tile for the calling wave.
// V is fetched through params.k_ptr, i.e. from the same cache buffer as K.
//   row_offset : per-lane row index, passed as the first index component
//   col        : per-lane column in elements (doubled below to form the second component)
//   k_idx      : selects the LDS slot (k_idx * 128*16*2 bytes)
//   n_idx      : selects the 128-column group of the head dimension via the scalar offset
//   block_idx  : KV block index, used only for the row bound (seqlen_k - block_idx * kBlockN)
//   offset_k   : element offset of the current page from params.k_ptr
// The load is not waited on here; callers pair it with s_waitcnt vmcnt(...) + s_barrier.
auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx, int block_idx, index_t offset_k) {
// NOTE(review): element_size is declared but never used; the literal 2 below presumably is the element size in bytes.
constexpr int element_size = 2;
// Build the 64-bit base address and split it into two 32-bit words via PtrWrapper.
PtrWrapper glob_ptr_k;
*(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(params.k_ptr) + offset_k * 2;
// Pack (k_row_stride * 2) into bits 16+ of the upper word and set bit 30.
// NOTE(review): presumably the byte stride and a swizzle/enable flag of the buffer descriptor -- confirm against the ISA manual.
glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16);
glob_ptr_k.latter |= 0x40000000;
// Assemble the 4-dword resource operand consumed by buffer_load below.
uint32x4_t global_addr_k = {0};
// readfirstlane makes the two address words wave-uniform so they can live in scalar registers.
global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former);
global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter);
// Number of rows still valid in this KV block (presumably the descriptor's record count, used for bounds checking -- TODO confirm).
global_addr_k[2] = seqlen_k - block_idx * kBlockN;
global_addr_k[3] = 0x00020000;
// NOTE(review): elements_per_thread is declared but never used in this lambda.
constexpr int elements_per_thread = 8;
int col_offset = col;
// int v_idx = row_offset;
// Column offset scaled by 2 (presumably elements -> bytes).
int offset_v = col_offset * 2;
// LDS destination: each wave owns a 16*32*2-byte stripe inside the slot selected by k_idx (128*16*2 bytes per slot).
int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// Two-component VGPR operand: [0] = row index (idxen), [1] = offset (offen).
uint32x2_t index_offset = {0};
index_offset[0] = row_offset;
index_offset[1] = offset_v;
// Scalar offset that selects the n_idx-th group of 128 columns (128 * 2 per group).
const int offset_s = n_idx * 128 * 2;
// Scheduling barriers keep the compiler from moving instructions across the asm block.
__builtin_amdgcn_sched_barrier(0);
// m0 receives the LDS address, then buffer_load_dwordx4 with the `lds` modifier is issued.
// NOTE(review): the asm writes m0 but declares no clobbers -- confirm this is safe with the surrounding code.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
};
Bf16_storage q_reg[18];
for (int i = 0; i < 18; i++)
{
constexpr int elements_per_thread = 8;
int row = lane_idx % 16;
int col = lane_idx / 16;
int row_offset = row + warp_idx * 16;
int col_offset = col * 8;
int offset_v = col_offset * 2 + i * 32 * 2;
q_reg[i].data_128 = __builtin_amdgcn_buffer_load_dwordx4(global_addr_q, row_offset, offset_v, false, false);
}
__syncthreads();
v4f acco_f32[32];
for (int i = 0; i < 32; i++)
{
acco_f32[i].x = 0.0f;
acco_f32[i].y = 0.0f;
acco_f32[i].z = 0.0f;
acco_f32[i].w = 0.0f;
}
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
auto float2bf16 = [] (float s) -> uint16_t {
    // Converts an fp32 value to the raw 16-bit bf16 pattern by keeping the
    // upper half of its bit representation.
    // FLASH_MLA_BF16_TYPE selects the conversion mode at compile time:
    //   0 (default) : plain truncation of the low 16 bits
    //   1           : add 0x8000 before truncating (round half up on the magnitude bits)
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
    constexpr uint32_t kRoundBias = (FLASH_MLA_BF16_TYPE == 1) ? 0x8000u : 0u;
    const uint32_t fp32_bits = reinterpret_cast<uint32_t const &>(s);
    return static_cast<uint16_t>((fp32_bits + kRoundBias) >> 16);
};
// int block_idx = 0;
// int cur_block_table = block_table[block_idx];
// index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride;
// auto [row_offset, col] = calc_row_and_col_k(block_idx);
// buffer_load_lds_k(row_offset, col, 0, offset_k);
// __syncthreads();
{
// if (thread0())
// {
// int k = 0;
// for (int i = 0; i < 64; i++)
// {
// for (int j = 0; j < 32; j++)
// {
// printf(" %.3f ", float(k_lds[k]));
// k++;
// }
// printf("\n");
// }
// }
// if (block0() && threadIdx.x < 64)
// {
// cutlass::bfloat16_t q[8];
// for (int i = 0; i < 8; i++)
// {
// q[i].storage = v_reg[0].data_array[i];
// // q[i].storage = q_reg[0].data_array[i];
// }
// printf("tidx %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %d\n ", threadIdx.x,
// float(q[0]),
// float(q[1]),
// float(q[2]),
// float(q[3]),
// float(q[4]),
// float(q[5]),
// float(q[6]),
// float(q[7]),
// v_lds_read_ptr
// );
// }
}
struct IsMaskBlock {};
struct IsFirstMaskBlock {};
struct IsNoMaskBlock {};
flash::Softmax<1> softmax;
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
int cur_block_table = block_table[block_idx];
v4f accs_f32[4];
for (int i = 0; i < 4; i++)
{
accs_f32[i].x = 0.0f;
accs_f32[i].y = 0.0f;
accs_f32[i].z = 0.0f;
accs_f32[i].w = 0.0f;
}
index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride;
auto [row_offset, col] = calc_row_and_col_k();
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 4, block_idx, offset_k); \
flash::qk_gemm<Element, k_val, 5>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
{
constexpr int k_val = (17);
buffer_load_lds_k(row_offset, col, k_val, block_idx, offset_k);
buffer_load_lds_k(row_offset, col, k_val - 1, block_idx, offset_k);
buffer_load_lds_k(row_offset, col, k_val - 2, block_idx, offset_k);
buffer_load_lds_k(row_offset, col, k_val - 3, block_idx, offset_k);
buffer_load_lds_k(row_offset, col, k_val - 4, block_idx, offset_k);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val, 5>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
LOAD_K_AND_QK_GEMM(16);
LOAD_K_AND_QK_GEMM(15);
LOAD_K_AND_QK_GEMM(14);
LOAD_K_AND_QK_GEMM(13);
LOAD_K_AND_QK_GEMM(12);
LOAD_K_AND_QK_GEMM(11);
LOAD_K_AND_QK_GEMM(10);
LOAD_K_AND_QK_GEMM(9);
LOAD_K_AND_QK_GEMM(8);
LOAD_K_AND_QK_GEMM(7);
LOAD_K_AND_QK_GEMM(6);
LOAD_K_AND_QK_GEMM(5);
LOAD_K_AND_QK_GEMM(4);
// LOAD_K_AND_QK_GEMM(3);
flash::qk_gemm<Element, k_val - 14, 5>(q_reg[k_val - 14].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 15, 5>(q_reg[k_val - 15].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 16, 5>(q_reg[k_val - 16].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 17, 5>(q_reg[k_val - 17].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
}
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", accs_f32[0][0], accs_f32[0][1]);
// }
if constexpr (!IS_NO_MASK_BLOCK) {
for (int i = 0; i < 16; ++i) {
int idx = i;
if constexpr (!T::Is_causal) {
if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 >= int(seqlen_k - block_idx * kBlockN))
{
#if defined(__gfx938__)
accs_f32[i/4][i%4] = -INFINITY;
#else
accs_f32[i%4][i/4] = -INFINITY;
#endif
}
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = (lane_idx % 16) + warp_idx * 16;;
int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk;
if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 > col_limit_right) {
#if defined(__gfx938__)
accs_f32[i/4][i%4] = -INFINITY;
#else
accs_f32[i%4][i/4] = -INFINITY;
#endif
}
}
}
}
Tensor scores = make_tensor<float>(Shape<_1, _16>{});
for (int i = 0; i < 16; i++) {
#if defined(__gfx938__)
scores(0, i) = accs_f32[i/4][i%4];
#else
scores(0, i) = accs_f32[i%4][i/4];
#endif
}
softmax.template softmax_rescale_o_prefill_4x1</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*//*Is_local=*/T::Is_causal>(scores, acco_f32, params.scale_softmax_log2);
Bf16_storage_x4 p[4];
for (int i = 0; i < 4; i++)
{
#if defined(__gfx938__)
p[i].data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4), 0, scores(0, i * 4 + 1), 0);
p[i].data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4 + 2), 0, scores(0, i * 4 + 3), 0);
#else
p[i].data[0] = float2bf16(scores(0, i * 4));
p[i].data[1] = float2bf16(scores(0, i * 4 + 1));
p[i].data[2] = float2bf16(scores(0, i * 4 + 2));
p[i].data[3] = float2bf16(scores(0, i * 4 + 3));
#endif
}
int row_offset_v[4];
for (int i = 0; i < 4; i++)
{
row_offset_v[i] = calc_row_and_col_v(i);
}
// __syncthreads();
// #if 1
{
constexpr int k_val = (0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0, block_idx, offset_k);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0, block_idx, offset_k);
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 0, block_idx, offset_k);
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 0, block_idx, offset_k);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1, block_idx, offset_k);
flash::pv_gemm<k_val + 1, 0>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1, block_idx, offset_k);
flash::pv_gemm<k_val + 2, 0>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 1, block_idx, offset_k);
flash::pv_gemm<k_val + 3, 0>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 1, block_idx, offset_k);
}
#define LOAD_V_AND_PV_GEMM(n) \
{ \
constexpr int k_val = (0); \
constexpr int n_val = (n); \
flash::pv_gemm<k_val, n_val * 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1, block_idx, offset_k); \
flash::pv_gemm<k_val + 1, n_val * 4>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1, block_idx, offset_k); \
flash::pv_gemm<k_val + 2, n_val * 4>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1, block_idx, offset_k); \
flash::pv_gemm<k_val + 3, n_val * 4>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1, block_idx, offset_k); \
}
LOAD_V_AND_PV_GEMM(1);
LOAD_V_AND_PV_GEMM(2);
{
constexpr int n_val = (3);
flash::pv_gemm<0, 12>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 13>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 14>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 15>(p[0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<1, 12>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 13>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 14>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 15>(p[1].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<2, 12>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 13>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 14>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 15>(p[2].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<3, 12>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 13>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 14>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 15>(p[3].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
}
};
constexpr int n_masking_steps = !T::Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
int n_block = n_block_max - 1;
if constexpr (n_masking_steps == 1) {
if (n_block >= n_block_min) {
process_one_block(n_block, IsFirstMaskBlock{});
}
n_block--;
} else {
int masking_step = 1;
if (n_block >= n_block_min) {
process_one_block(n_block, IsFirstMaskBlock{});
}
n_block--;
for (; n_block >= n_block_min && masking_step < n_masking_steps; ++masking_step, --n_block) {
process_one_block(n_block, IsMaskBlock{});
}
}
for(; n_block >= n_block_min; --n_block) {
process_one_block(n_block, IsNoMaskBlock{});
}
using ElementAccum = float;
// if constexpr (!use_split_kv)
if (!is_split)
{
using ElementO = Element;
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM;
constexpr bool Split = false;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1</*Is_dropout=*/false, Split>(acco_f32, params.scale_softmax);
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// }
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int row, col;
// int warpid = tidx / 64;
for (int mi = 0; mi < 1; mi++) {
row = mi * kBlockM + lane_idx % 16 + warp_idx * 16;
if (row < params.q_seq_per_hk - m_block * kBlockM) {
for (int ni = 0; ni < 16; ++ni) {
#if defined(__gfx938__)
Bf16_storage res;
col = (lane_idx / 16) * 8 + ni * 32 ;
res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0], 0, acco_f32[ni * 2 + 1][0], 0);
res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1], 0, acco_f32[ni * 2 + 1][1], 0);
res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2], 0, acco_f32[ni * 2 + 1][2], 0);
res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3], 0, acco_f32[ni * 2 + 1][3], 0);
*(__fp16x8_t*)(&gOaccum(row, col)) = res.data_128;
#else
col = (lane_idx / 16) * 2 + ni * 32 ;
using result_type = cutlass::Array<Element, 2>;
for (int ei = 0; ei < 4; ei++)
{
result_type res;
Element e0, e1;
e0.storage = float2bf16(acco_f32[ni * 2][ei]);
e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei]);
res[0] = e0;
res[1] = e1;
// gO(row, col) = res[0];
// gO(row, col + 1) = res[1];
*(result_type*)(&gOaccum(row, col)) = res;
col += 8;
}
#endif
}
// for (int n = 0; n < 1; n++) {
// col = (tidx % 64 / 16) + warpid * 32 + n * 128;
// for (int ei = 0; ei < 8; ei ++) {
// gOaccum(row, col) = rO(ei, m, n);
// col += 4;
// }
// }
gLSEaccum(row) = lse(mi);
}
}
}
}
else
{
using ElementO = float;
int num_splits = params.total_num_splits / params.b;
int split_idx = (blockIdx.z / params.b) + bidb * num_splits;
constexpr bool Split = true;
const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1</*Is_dropout=*/false, Split>(acco_f32, params.scale_softmax);
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int row, col;
// int warpid = tidx / 64;
for (int mi = 0; mi < 1; mi++) {
row = mi * kBlockM + lane_idx % 16 + warp_idx * 16;
if (row < params.q_seq_per_hk - m_block * kBlockM) {
for (int ni = 0; ni < 16; ++ni) {
#if defined(__gfx938__)
col = (lane_idx / 16) * 8 + ni * 32 ;
for (int ei = 0; ei < 4; ei++)
{
gOaccum(row, col) = acco_f32[ni * 2][ei];
gOaccum(row, col + 1) = acco_f32[ni * 2 + 1][ei];
col += 2;
}
#else
col = (lane_idx / 16) * 2 + ni * 32 ;
for (int ei = 0; ei < 4; ei++)
{
gOaccum(row, col) = acco_f32[ni * 2][ei];
gOaccum(row, col + 1) = acco_f32[ni * 2 + 1][ei];
col += 8;
}
#endif
}
// for (int n = 0; n < 1; n++) {
// col = (tidx % 64 / 16) + warpid * 32 + n * 128;
// for (int ei = 0; ei < 8; ei ++) {
// gOaccum(row, col) = rO(ei, m, n);
// col += 4;
// }
// }
gLSEaccum(row) = lse(mi);
}
}
}
}
}
// Host-side launcher for the dense MLA decode kernels.
// Validates the head dimensions against Config and then dispatches on
// causality (compile-time via BOOL_SWITCH) and on the head configuration:
//   * h_q >= 64 and h_k == 1 : the BLOCK_M = 64 kernel, with or without split-KV
//   * otherwise              : the generic split-KV kernel scheduled over num_sm_parts
// All launches go to params.stream; launch errors are checked at the end.
template<typename InputT>
void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
    FLASH_ASSERT(params.d == Config::HEAD_DIM_K);
    FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);

    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        const bool use_block_m_64 = (params.h_q >= 64) && (params.h_k == 1);
        if (!use_block_m_64) {
            // Generic path: grid.z enumerates the SM partitions.
            using T = Traits<InputT, Is_causal>;
            constexpr size_t smem_size = 65536;
            const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
            auto generic_kernel = &flash_fwd_splitkv_mla_kernel<T>;
            generic_kernel<<<dim3(num_m_block, params.h_k, params.num_sm_parts), T::NUM_THREADS, smem_size, params.stream>>>(params);
        } else {
            using T = Traits_Block_M_64<InputT, Is_causal>;
            // 20480 bytes of dynamic shared memory for the BLOCK_M = 64 kernel.
            constexpr size_t smem_size = 16384 + 4096;
            const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
            if (!params.use_split_kv) {
                // One thread block per (m-block, kv head, batch entry).
                auto dense_kernel = &flash_fwd_splitkv_mla_block_m_64_kernel<T>;
                dense_kernel<<<dim3(num_m_block, params.h_k, params.b), T::NUM_THREADS, smem_size, params.stream>>>(params);
            } else {
                // Split-KV: grid.z enumerates all (batch, split) pairs.
                auto split_kernel = &flash_fwd_splitkv_mla_block_m_64_kernel<T, true>;
                split_kernel<<<dim3(num_m_block, params.h_k, params.total_num_splits), T::NUM_THREADS, smem_size, params.stream>>>(params);
            }
        }
    });

    // The former cudaLaunchKernelEx / PDL launch path is not used by this launcher.
    CHECK_CUDA_KERNEL_LAUNCH();
}
......
......@@ -127,3 +127,26 @@ struct Traits {
// Compile-time configuration of the BLOCK_M = 64 dense decode kernel.
//   InputT_    : element type of Q/K/V; must be cutlass::bfloat16_t or cutlass::half_t
//   Is_causal_ : whether causal masking is compiled in
template<typename InputT_, bool Is_causal_>
struct Traits_Block_M_64 {
    static_assert(std::is_same_v<InputT_, cutlass::bfloat16_t> || std::is_same_v<InputT_, cutlass::half_t>);

    // Element types (several aliases are kept for code written against different naming styles).
    using InputT = InputT_;
    using Element = InputT;
    using elem_type = Element;
    using ElementAccum = float;

    static constexpr bool Is_causal = Is_causal_;

    // Tile and launch geometry.
    static constexpr int BLOCK_SIZE_M = 64;     // query rows per thread block
    static constexpr int PAGE_BLOCK_SIZE = 64;  // KV rows per block
    static constexpr int HEAD_DIM_K = 576;
    static constexpr int HEAD_DIM_V = 512;
    static constexpr int NUM_THREADS = 256;
    static constexpr int kNWarps = 4;

    // Same values under the k-prefixed names.
    static constexpr int kBlockM = BLOCK_SIZE_M;
    static constexpr int kBlockN = PAGE_BLOCK_SIZE;
    static constexpr int kHeadDim = HEAD_DIM_K;
    static constexpr int kHeadDimV = HEAD_DIM_V;
};
\ No newline at end of file
......@@ -236,7 +236,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
// int v_idx = row_offset;
int offset_v = col_offset * 2;
int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 1) * 512 * 16 * 2 + n_idx * 128 * 16 * 2;
int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t index_offset = {0};
index_offset[0] = row_offset;
......@@ -474,7 +474,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#endif
}
softmax.template softmax_rescale_o_prefill_4x1</*Is_first=*/IS_FIRST_BLOCK, /*Check_inf=*//*Is_local=*/false>(scores, acco_f32, params.sm_scale_div_log2);
softmax.template softmax_rescale_o_prefill_4x1</*Is_first=*/IS_FIRST_BLOCK, /*Check_inf=*//*Is_local=*/true>(scores, acco_f32, params.sm_scale_div_log2);
Bf16_storage_x4 p[4];
for (int i = 0; i < 4; i++)
......@@ -500,9 +500,9 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{
constexpr int k_val = (0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 2);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 3);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0);
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 0);
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 0);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
......@@ -513,109 +513,111 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1);
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 0>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1);
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 0>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 2);
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 1);
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 0>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 3);
}
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 1);
#define LOAD_V_AND_PV_GEMM(k) \
}
#define LOAD_V_AND_PV_GEMM(n) \
{ \
constexpr int k_val = (k); \
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
constexpr int k_val = (0); \
constexpr int n_val = (n); \
flash::pv_gemm<k_val, n_val * 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0); \
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1); \
flash::pv_gemm<k_val + 1, n_val * 4>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1); \
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1); \
flash::pv_gemm<k_val + 2, n_val * 4>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 2); \
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1); \
flash::pv_gemm<k_val + 3, n_val * 4>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 3); \
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1); \
}
LOAD_V_AND_PV_GEMM(1);
LOAD_V_AND_PV_GEMM(2);
{
constexpr int k_val = (3);
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
constexpr int n_val = (3);
flash::pv_gemm<0, 12>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 13>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 14>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 15>(p[0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 12>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 13>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 14>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 15>(p[1].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 12>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 13>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 14>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 15>(p[2].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 12>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 13>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 14>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 15>(p[3].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
}
#else
#define LOAD_V_AND_PV_GEMM(k) \
{ \
......
......@@ -58,6 +58,10 @@ struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
float *__restrict__ oaccum_ptr;
cudaStream_t stream;
bool use_split_kv;
int partition_block_nums;
};
struct DenseAttnDecodeParams_fp8 : public DenseAttnDecodeParams {
......@@ -127,6 +131,12 @@ struct CombineParams {
float* attn_sink; // [h_q], may be nullptr
cudaStream_t stream;
bool use_split_kv;
int num_splits;
int *__restrict__ seqlens_k_ptr;
int partition_block_nums;
};
struct GetDecodeSchedMetaParams {
......
......@@ -621,7 +621,7 @@ struct Softmax {
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !true
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
......
......@@ -1553,7 +1553,7 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
#endif
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
template<typename Element, int k_idx>
template<typename Element, int k_idx, int k_mod = 4>
__forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds_read_ptr, v4f* accs_f32)
{
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
......@@ -1563,7 +1563,7 @@ __forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds
__fp16x4_t data_64[2];
uint16_t data_array[8];
};
constexpr int k_idx_even = k_idx % 4;
constexpr int k_idx_even = k_idx % k_mod;
constexpr int n_offset = 16 * 32;
constexpr int k_offset = k_idx_even * 64 * 32;
Bf16_storage q_reg;
......@@ -1616,7 +1616,7 @@ typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
template<int k_idx, int n_idx_val>
__forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr, v4f* acco_f32)
{
constexpr int k_idx_even = k_idx % 1;
constexpr int k_idx_even = k_idx;
constexpr int n_offset = 16 * 32 * 2;
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
union Bf16_storage {
......@@ -1624,11 +1624,11 @@ __forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr,
__fp16x4_t data_64[2];
uint16_t data_array[8];
};
constexpr int k_offset = k_idx_even * 16 * 512 * 2;
constexpr int k_offset = k_idx_even * 16 * 128 * 2;
// #if 1
Bf16_storage v_reg;
v_reg.data_128 = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(v_lds_read_ptr), k_offset + n_idx_val * n_offset);
v_reg.data_128 = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(v_lds_read_ptr), k_offset + (n_idx_val % 4) * n_offset);
#if defined(__gfx938__)
acco_f32[n_idx_val * 2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[0], acco_f32[n_idx_val * 2], true, false);
acco_f32[n_idx_val * 2 + 1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[1], acco_f32[n_idx_val * 2 + 1], true, false);
......
......@@ -172,7 +172,7 @@ def test_flash_mla(t: TestParam):
assert is_correct
if t.test_performance:
time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla_kernel")
time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla")
mean_attended_seqlens = cache_seqlens.float().mean().item()
compute_volume_flop = t.b * t.h_q * t.s_q * sum([
......@@ -226,7 +226,7 @@ def main(torch_dtype):
TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, h_q = h_q, test_performance=True)
for is_causal in [False, True]
for s_q in [1, 2]
for h_q in [16, 128]
for h_q in [16, 64, 128]
for s_k in [4096, 8192, 16384, 32768]
]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment