Commit c85c787e authored by zhanghj2's avatar zhanghj2
Browse files

Merge branch 'feature/kimi-nhead64-dense' into 'master'

Feature/kimi nhead64 dense

See merge request dcutoolkit/deeplearing/flashmla!10
parents a45f646b aec17474
......@@ -92,7 +92,7 @@ dense_attn_decode_interface(
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
if (!tile_scheduler_metadata.has_value()) {
if (!tile_scheduler_metadata.has_value() && ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1)) {
tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
......@@ -125,20 +125,6 @@ dense_attn_decode_interface(
if (const char* val = std::getenv("FLASH_MLA_PRINT_PARAM")) {
print_param = (std::string(val) == "1");
}
if (print_param) {
fprintf(stderr, "[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d \n",
arch.archName.c_str(),
batch_size,
seqlen_q_ori,
num_heads_q,
head_size_k,
max_num_blocks_per_seq,
num_blocks,
page_block_size,
num_heads_k
);
}
// Set the sizes
DenseAttnDecodeParams params;
params.b = batch_size;
......@@ -174,10 +160,10 @@ dense_attn_decode_interface(
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
if ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1) {
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
......@@ -186,7 +172,65 @@ dense_attn_decode_interface(
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
params.use_split_kv = false;
} else {
bool use_split_kv = true;
int num_m_blocks = (params.q_seq_per_hk + 64 - 1) / 64;
int num_sms = arch.num_sms;
int num_splits = num_sms * 3 / (num_m_blocks * params.b);
if (max_num_blocks_per_seq >= 32768 / 64) {
num_splits = 32;
} else if (max_num_blocks_per_seq >= 16384 / 64) {
num_splits = 32;
} else if (max_num_blocks_per_seq >= 8192 / 64) {
num_splits = 16;
} else if (max_num_blocks_per_seq >= 4096 / 64) {
num_splits = 8;
} else if (max_num_blocks_per_seq >= 2048 / 64) {
num_splits = 4;
} else {
num_splits = 1;
}
if (params.b >= 128) {
num_splits = 1;
}
if (num_splits <= 1) {
use_split_kv = false;
}
else {
num_splits = std::min(num_splits, 240);
params.partition_block_nums = max_num_blocks_per_seq / num_splits;
}
if (params.partition_block_nums <= 4) {
use_split_kv = false;
}
params.use_split_kv = use_split_kv;
params.total_num_splits = params.b * num_splits;
at::Tensor lse_accum = torch::empty({params.total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({params.total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(out_accum);
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
}
if (print_param) {
fprintf(stderr, "[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d use_split_kv = %d num_splits %d params.partition_block_nums = %d \n",
arch.archName.c_str(),
batch_size,
seqlen_q_ori,
num_heads_q,
head_size_k,
max_num_blocks_per_seq,
num_blocks,
page_block_size,
num_heads_k,
params.use_split_kv,
params.total_num_splits / params.b,
params.partition_block_nums
);
}
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) {
......@@ -220,18 +264,24 @@ dense_attn_decode_interface(
params.num_sm_parts,
nullptr,
at::cuda::getCurrentCUDAStream().stream()
at::cuda::getCurrentCUDAStream().stream(),
params.use_split_kv,
params.total_num_splits / params.b,
params.seqlens_k_ptr,
params.partition_block_nums
};
if ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1 || params.use_split_kv) {
if (q_dtype == torch::kBFloat16) {
gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
#ifndef FLASH_MLA_DISABLE_FP16
gfx9::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);
#endif
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
}
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
......
......@@ -16,7 +16,7 @@ using namespace cute;
namespace gfx9::decode {
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS, bool USE_SPLIT_KV=false>
__global__ void __launch_bounds__(NUM_THREADS, 1)
flash_fwd_mla_combine_kernel(const CombineParams params) {
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
......@@ -33,12 +33,36 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
return;
}
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
const int my_num_splits = end_split_idx - start_split_idx;
int start_split_idx;
int end_split_idx;
int my_num_splits;
if constexpr (USE_SPLIT_KV)
{
start_split_idx = batch_idx * params.num_splits;
end_split_idx = (batch_idx + 1) * params.num_splits;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
end_split_idx = std::min(cute::ceil_div(cute::ceil_div(seqlen_k, 64), params.partition_block_nums), params.num_splits) + start_split_idx;
// if (lane_idx == 0 && batch_idx == 61)
// {
// printf(" batch_idx = %d start_split_idx = %d end_split_idx = %d seqlen_k = %d \n",batch_idx, start_split_idx, end_split_idx, seqlen_k);
// }
my_num_splits = end_split_idx - start_split_idx;
if (my_num_splits == 1) {
return;
}
}
else
{
start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
my_num_splits = end_split_idx - start_split_idx;
if (my_num_splits == 1) {
return;
}
}
// FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
......@@ -245,6 +269,9 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 240) { \
constexpr static int NAME = 240; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
......@@ -255,29 +282,33 @@ template<typename ElementT>
void run_flash_mla_combine_kernel(CombineParams &params) {
static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA
FLASH_ASSERT(params.d_v == HEAD_DIM_V);
if (params.use_split_kv)
{
MLA_NUM_SPLITS_SWITCH(params.num_splits, NUM_SPLITS, [&] {
constexpr int BLOCK_SIZE_M = 4;
constexpr int NUM_THREADS = BLOCK_SIZE_M*64;
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS, true>;
combine_kernel<<<dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
NUM_THREADS,
smem_size,
params.stream>>>(params);
});
}
else
{
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
constexpr int BLOCK_SIZE_M = 4;
constexpr int NUM_THREADS = BLOCK_SIZE_M*64;
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
// CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
// cudaLaunchAttribute attribute[1];
// attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
// attribute[0].val.programmaticStreamSerializationAllowed = 1;
// cudaLaunchConfig_t combine_kernel_config = {
// dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
// dim3(NUM_THREADS, 1, 1),
// 0,
// params.stream,
// attribute,
// 1
// };
combine_kernel<<<dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
NUM_THREADS,
smem_size,
params.stream>>>(params);
});
}
CHECK_CUDA_KERNEL_LAUNCH();
}
......
This diff is collapsed.
......@@ -127,3 +127,26 @@ struct Traits {
template<typename InputT_, bool Is_causal_>
struct Traits_Block_M_64 {
using InputT = InputT_;
static constexpr bool Is_causal = Is_causal_;
static constexpr int BLOCK_SIZE_M = 64;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
static constexpr int NUM_THREADS = 256;
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);
static constexpr int kBlockM = BLOCK_SIZE_M;
static constexpr int kBlockN = PAGE_BLOCK_SIZE;
static constexpr int kHeadDim = HEAD_DIM_K;
static constexpr int kHeadDimV = HEAD_DIM_V;
static constexpr int kNWarps = 4;
using Element = InputT;
using elem_type = Element;
using ElementAccum = float;
};
\ No newline at end of file
......@@ -236,7 +236,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
// int v_idx = row_offset;
int offset_v = col_offset * 2;
int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 1) * 512 * 16 * 2 + n_idx * 128 * 16 * 2;
int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t index_offset = {0};
index_offset[0] = row_offset;
......@@ -474,7 +474,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#endif
}
softmax.template softmax_rescale_o_prefill_4x1</*Is_first=*/IS_FIRST_BLOCK, /*Check_inf=*//*Is_local=*/false>(scores, acco_f32, params.sm_scale_div_log2);
softmax.template softmax_rescale_o_prefill_4x1</*Is_first=*/IS_FIRST_BLOCK, /*Check_inf=*//*Is_local=*/true>(scores, acco_f32, params.sm_scale_div_log2);
Bf16_storage_x4 p[4];
for (int i = 0; i < 4; i++)
......@@ -500,9 +500,9 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{
constexpr int k_val = (0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 2);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 3);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0);
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 0);
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 0);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
......@@ -513,109 +513,111 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1);
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 0>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 1, 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1);
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 0>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 2, 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 2);
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 1);
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 0>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val + 3, 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 3);
}
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 1);
#define LOAD_V_AND_PV_GEMM(k) \
}
#define LOAD_V_AND_PV_GEMM(n) \
{ \
constexpr int k_val = (k); \
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
constexpr int k_val = (0); \
constexpr int n_val = (n); \
flash::pv_gemm<k_val, n_val * 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0); \
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1); \
flash::pv_gemm<k_val + 1, n_val * 4>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1); \
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1); \
flash::pv_gemm<k_val + 2, n_val * 4>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 2); \
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1); \
flash::pv_gemm<k_val + 3, n_val * 4>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 3); \
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1); \
}
LOAD_V_AND_PV_GEMM(1);
LOAD_V_AND_PV_GEMM(2);
{
constexpr int k_val = (3);
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
constexpr int n_val = (3);
flash::pv_gemm<0, 12>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 13>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 14>(p[0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<0, 15>(p[0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 12>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 13>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 14>(p[1].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<1, 15>(p[1].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 12>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 13>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 14>(p[2].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<2, 15>(p[2].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 12>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 13>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 14>(p[3].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<3, 15>(p[3].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
}
#else
#define LOAD_V_AND_PV_GEMM(k) \
{ \
......
......@@ -58,6 +58,10 @@ struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
float *__restrict__ oaccum_ptr;
cudaStream_t stream;
bool use_split_kv;
int partition_block_nums;
};
struct DenseAttnDecodeParams_fp8 : public DenseAttnDecodeParams {
......@@ -127,6 +131,12 @@ struct CombineParams {
float* attn_sink; // [h_q], may be nullptr
cudaStream_t stream;
bool use_split_kv;
int num_splits;
int *__restrict__ seqlens_k_ptr;
int partition_block_nums;
};
struct GetDecodeSchedMetaParams {
......
......@@ -621,7 +621,7 @@ struct Softmax {
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !true
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
......
......@@ -1553,7 +1553,7 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
#endif
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
template<typename Element, int k_idx>
template<typename Element, int k_idx, int k_mod = 4>
__forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds_read_ptr, v4f* accs_f32)
{
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
......@@ -1563,7 +1563,7 @@ __forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds
__fp16x4_t data_64[2];
uint16_t data_array[8];
};
constexpr int k_idx_even = k_idx % 4;
constexpr int k_idx_even = k_idx % k_mod;
constexpr int n_offset = 16 * 32;
constexpr int k_offset = k_idx_even * 64 * 32;
Bf16_storage q_reg;
......@@ -1616,7 +1616,7 @@ typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
template<int k_idx, int n_idx_val>
__forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr, v4f* acco_f32)
{
constexpr int k_idx_even = k_idx % 1;
constexpr int k_idx_even = k_idx;
constexpr int n_offset = 16 * 32 * 2;
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
union Bf16_storage {
......@@ -1624,11 +1624,11 @@ __forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr,
__fp16x4_t data_64[2];
uint16_t data_array[8];
};
constexpr int k_offset = k_idx_even * 16 * 512 * 2;
constexpr int k_offset = k_idx_even * 16 * 128 * 2;
// #if 1
Bf16_storage v_reg;
v_reg.data_128 = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(v_lds_read_ptr), k_offset + n_idx_val * n_offset);
v_reg.data_128 = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(v_lds_read_ptr), k_offset + (n_idx_val % 4) * n_offset);
#if defined(__gfx938__)
acco_f32[n_idx_val * 2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[0], acco_f32[n_idx_val * 2], true, false);
acco_f32[n_idx_val * 2 + 1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[1], acco_f32[n_idx_val * 2 + 1], true, false);
......
......@@ -172,7 +172,7 @@ def test_flash_mla(t: TestParam):
assert is_correct
if t.test_performance:
time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla_kernel")
time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla")
mean_attended_seqlens = cache_seqlens.float().mean().item()
compute_volume_flop = t.b * t.h_q * t.s_q * sum([
......@@ -226,7 +226,7 @@ def main(torch_dtype):
TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, h_q = h_q, test_performance=True)
for is_causal in [False, True]
for s_q in [1, 2]
for h_q in [16, 128]
for h_q in [16, 64, 128]
for s_k in [4096, 8192, 16384, 32768]
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment