Commit 8b581a54 authored by zhanghj2's avatar zhanghj2
Browse files

支持kme dense bf16

parent c85c787e
......@@ -24,9 +24,9 @@ dense_attn_decode_interface(
) {
// Check arch
Arch arch = Arch();
if (!arch.is_gfx93x()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
}
// if (!arch.is_gfx93x()) {
// TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
// }
// Check data types
auto q_dtype = q.dtype();
......@@ -92,7 +92,7 @@ dense_attn_decode_interface(
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
if (!tile_scheduler_metadata.has_value() && ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1)) {
if (!tile_scheduler_metadata.has_value() && (arch.is_gfx928() || (num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1)) {
tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
......@@ -159,8 +159,8 @@ dense_attn_decode_interface(
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
if ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1) {
params.is_gfx928 = arch.is_gfx928();
if ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1 || arch.is_gfx928()) {
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
......@@ -271,7 +271,7 @@ dense_attn_decode_interface(
params.partition_block_nums
};
if ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1 || params.use_split_kv) {
if ((num_heads_q < 64 && num_heads_k == 1) || num_heads_k > 1 || params.use_split_kv || arch.is_gfx928()) {
if (q_dtype == torch::kBFloat16) {
gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else if (q_dtype == torch::kHalf) {
......
......@@ -1178,6 +1178,364 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
}
}
template<typename T>
__device__ void
compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit)
{
extern __shared__ char shared_memory[];
using SharedMemoryPlan = typename T::SharedMemoryPlan;
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(shared_memory);
const int tidx = threadIdx.x;
constexpr int kBlockM = T::BLOCK_SIZE_M;
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
constexpr int kHeadDim = T::HEAD_DIM_K;
constexpr int kHeadDimV = T::HEAD_DIM_V;
using Element = T::InputT;
using index_t = int64_t;
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
const index_t row_offset_k = (bidh) * params.k_head_stride;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.k_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), typename T::SmemLayoutQ{});
Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutV{});
Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutK{});
Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), typename T::SmemLayoutP{});
Tensor sVt = make_tensor(sV.data(), typename T::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle{});
Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), typename T::SmemLayoutRow{});
Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), typename T::SmemLayoutRow{});
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<Element, cutlass::half_t>,
MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
>;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_1, Int<4>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
typename T::TiledMma_O tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
#if 1
typename T::GmemTiledCopyQ gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQgQ)));
if (tidx < 128)
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, false>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.q_seq_per_hk - m_block * kBlockM);
__syncthreads();
auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
#else
#endif
typename T::GmemTiledCopyK gmem_tiled_copy_K;
auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx);
Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK)));
Tensor tKcK = gmem_thr_copy_K.partition_S(cK);
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKgK)));
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSgK = smem_thr_copy_K.partition_S(gK);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
Tensor tSrK = thr_mma.partition_fragment_B(sK);
Tensor tKcK_smem = smem_thr_copy_K.partition_S(cK);
Tensor tKpK_smem = make_tensor<bool>(make_shape(size<2>(tSgK)));
Tensor tSrK_smem = thr_mma.partition_fragment_B(gK);
typename T::GmemTiledCopyV gmem_tiled_copy_V;
auto gmem_thr_copy_V = gmem_tiled_copy_V.get_thread_slice(tidx);
Tensor tVgV = gmem_thr_copy_V.partition_S(gV);
Tensor tVsV = gmem_thr_copy_V.partition_D(sV);
Tensor cV = make_identity_tensor(make_shape(size<0>(gV), size<1>(gV)));
Tensor tVcV = gmem_thr_copy_V.partition_S(cV);
Tensor tVpV = make_tensor<bool>(make_shape(size<2>(tVgV)));
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16, Element>{}, tiled_mma_o);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle);
constexpr int n_masking_steps = !T::Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int n_block = n_block_max - 1;
// constexpr static int k0_lds_loops = 0;
constexpr static int k0_lds_loops = 16;
constexpr static int k0_loops = size<2>(tSrK_smem);
constexpr static int k1_loops = size<2>(tOrVt);
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
int cur_block_table;
index_t offset_k;
constexpr static int BUFFER_SIZE = 4;
uint128_t buffer[BUFFER_SIZE];
if (n_block >= n_block_min)
{
cur_block_table = block_table[n_block];
offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k);
flash::buffer_load_copy<false, true, false>(gK, buffer[0], 0, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash::buffer_load_copy<false, true, false>(gK, buffer[1], 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash::buffer_load_copy<false, true, false>(gK, buffer[2], 2, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
}
#if 1
#pragma unroll
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
asm volatile("s_barrier\n\t");
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s);
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// 计算0~11
#if 1
#pragma unroll
for (int i = 0; i < k0_lds_loops - BUFFER_SIZE + 1; i++) {
// asm volatile("s_waitcnt vmcnt(3) \n\t \n\t");
flash::asm_ds_write(buffer[i % BUFFER_SIZE], tKsK, i);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
flash::buffer_load_copy<false, true, false>(gK, buffer[(i + BUFFER_SIZE - 1) % BUFFER_SIZE], i + BUFFER_SIZE - 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
// asm volatile("s_barrier\n\t");
}
// asm volatile("s_barrier\n\t");
#endif
#if 0
#else
// 计算 13-15
const int k_idx = k0_lds_loops - BUFFER_SIZE + 1;
flash::asm_ds_write(buffer[k_idx % BUFFER_SIZE], tKsK, k_idx);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
flash::asm_ds_write(buffer[(k_idx + 1) % BUFFER_SIZE], tKsK, k_idx + 1);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 1), tSrK_copy_view(_, _, k_idx + 1));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 1), tSrK(_, _, k_idx + 1), acc_s);
flash::asm_ds_write(buffer[(k_idx + 2) % BUFFER_SIZE], tKsK, k_idx + 2);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 2), tSrK_copy_view(_, _, k_idx + 2));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 2), tSrK(_, _, k_idx + 2), acc_s);
// asm volatile("s_barrier\n\t");
// 读取16-17
flash::buffer_load_copy<false, true, true>(gK, buffer[1], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash::buffer_load_copy<false, true, true>(gK, buffer[2], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash::buffer_to_tensor(buffer[1], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
flash::buffer_to_tensor(buffer[2], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
asm volatile("s_barrier\n\t");
#endif
const bool is_masking_step = masking_step > 0;
const bool is_first_masking_step = masking_step == n_masking_steps;
if (is_masking_step) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!T::Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
}
// We have key_padding_mask so we'll need to Check_inf
is_first_masking_step
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/T::Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2)
: is_masking_step ?
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/T::Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(acc_s);
Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP);
__syncthreads();
#if 1
// 第15块已经读取到了buffer[3]中
flash::asm_ds_write(buffer[3], tVsV, 15);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
#endif
gK.data() = gK.data() + (-offset_k);
if (n_block > n_block_min) {
cur_block_table = block_table[n_block - 1];
offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k);
flash::buffer_load_copy<true, true, false>(gK, buffer[0], 0, params.k_row_stride, offset_k);
flash::buffer_load_copy<true, true, false>(gK, buffer[1], 1, params.k_row_stride, offset_k);
flash::buffer_load_copy<true, true, false>(gK, buffer[2], 2, params.k_row_stride, offset_k);
}
Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
#pragma unroll
for (int i = 0; i < k1_loops; i++) {
cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i));
cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o);
}
asm volatile(" s_barrier\n\t");
}
#endif
using ElementAccum = float;
if (NoSplit)
{
using ElementO = Element;
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM;
constexpr bool Split = false;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// }
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
Tensor rO = flash::convert_type<ElementO>(acc_o);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16;
if (row < params.q_seq_per_hk - m_block * kBlockM) {
for (int n = 0; n < size<2>(acc_o); n++) {
col = (tidx % 64 / 16) + warpid * 32 + n * 128;
for (int ei = 0; ei < 8; ei ++) {
gOaccum(row, col) = rO(ei, m, n);
col += 4;
}
}
}
}
}
}
else
{
using ElementO = float;
int split_idx = params.num_splits_ptr[bidb] + n_split_idx;
constexpr bool Split = true;
const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16;
if (row < params.q_seq_per_hk - m_block * kBlockM) {
for (int n = 0; n < size<2>(acc_o); n++) {
col = (tidx % 64 / 16) + warpid * 32 + n * 128;
for (int ei = 0; ei < 8; ei ++) {
gOaccum(row, col) = acc_o(ei, m, n);
col += 4;
}
}
}
}
}
}
}
template<typename T>
__global__ void __launch_bounds__(T::NUM_THREADS, 1)
......@@ -1199,9 +1557,15 @@ flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
if (batch_idx > sched_meta.begin_req_idx) {
__syncthreads();
}
#if defined(__gfx928__)
compute_attn_1rowblock_splitkv_mla_gfx928<T>(params, batch_idx, bidh, m_block, n_split_idx,
seqlen_k, start_block_idx, end_block_idx, is_no_split
);
#else
compute_attn_1rowblock_splitkv_mla_gfx936<T>(params, batch_idx, bidh, m_block, n_split_idx,
seqlen_k, start_block_idx, end_block_idx, is_no_split
);
#endif
}
}
......@@ -1209,6 +1573,7 @@ flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
template<typename T, bool use_split_kv=false>
__global__ void __launch_bounds__(T::NUM_THREADS, 1)
flash_fwd_splitkv_mla_block_m_64_kernel(const DenseAttnDecodeParams params) {
#if defined(__gfx936__) || defined(__gfx938__)
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
......@@ -1908,7 +2273,7 @@ flash_fwd_splitkv_mla_block_m_64_kernel(const DenseAttnDecodeParams params) {
}
}
}
#endif
}
......@@ -1918,7 +2283,7 @@ void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.h_q >= 64 && params.h_k == 1) {
if (params.h_q >= 64 && params.h_k == 1 && !params.is_gfx928) {
using T = Traits_Block_M_64<InputT, Is_causal>;
constexpr size_t smem_size = 16384 + 4096;
if (params.use_split_kv)
......
......@@ -102,6 +102,14 @@ struct Traits {
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{}));
using GmemLayoutAtomK = Layout<Shape <_64, _4>,
Stride< _4, _1>>;
using GmemTiledCopyK = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomK{},
Layout<Shape<_1, _8>>{}));
using GmemTiledCopyV = GmemTiledCopyK;
struct SharedMemoryPlan {
......
......@@ -969,7 +969,9 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::devfunc(const SparseAttnD
template<typename Kernel>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1)
flash_fwd_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams params) {
#if defined(__gfx936__) || defined(__gfx938__)
Kernel::devfunc(params);
#endif
}
template<ModelType MODEL_TYPE, int NUM_HEADS>
......
......@@ -1287,9 +1287,9 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
template<typename Kernel>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1)
sparse_attn_fwd_kernel(const SparseAttnFwdParams params) {
// #if defined(__gfx936__)
#if defined(__gfx936__) || defined(__gfx938__)
Kernel::devfunc(params);
// #endif
#endif
}
template<int D_QK, bool HAVE_TOPK_LENGTH>
......
......@@ -61,7 +61,7 @@ struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
bool use_split_kv;
int partition_block_nums;
bool is_gfx928;
};
struct DenseAttnDecodeParams_fp8 : public DenseAttnDecodeParams {
......
......@@ -621,7 +621,17 @@ lds_direct_copy_for_prefill_sparse_mla(
:);
}
template<
class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
asm_ds_write(const uint128_t & src, Tensor<SrcEngine, SrcLayout> & dst, int k_idx)
{
uint128_t* d = reinterpret_cast<uint128_t*>(&dst(0, 0, k_idx));
d[0] = src;
}
template <
bool Is_even_MN=true,
bool Is_even_K=true,
......
......@@ -31,7 +31,7 @@ def get_features_args():
def get_arch_flags():
arch_flags = []
arch_flags.append("--offload-arch=gfx938;gfx936")
arch_flags.append("--offload-arch=gfx938;gfx936;gfx928")
return arch_flags
# def get_nvcc_thread_args():
......
......@@ -140,6 +140,9 @@ def reference_torch(
out_ref = out_ref.to(q.dtype)
return out_ref, lse_ref
def get_gcn_arch_name() -> str:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return GPU_ARCH.split(':')[0]
@torch.inference_mode()
def test_flash_mla(t: TestParam):
......@@ -166,6 +169,12 @@ def test_flash_mla(t: TestParam):
out_ans, lse_ans = run_flash_mla()
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal)
if get_gcn_arch_name() == "gfx928":
lse_abs_diff = (lse_ans - lse_ref).max().abs().item()
out_abs_diff = (out_ref - out_ans).max().abs().item()
print("lse_abs_diff ", lse_abs_diff, out_abs_diff)
assert out_abs_diff <= 4e-3
else:
is_correct = True
is_correct &= kk.check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
is_correct &= kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
......
......@@ -214,7 +214,14 @@ def main(torch_dtype, is_prof=False):
for varlen in [False]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
def get_gcn_arch_name() -> str:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return GPU_ARCH.split(':')[0]
if __name__ == "__main__":
if get_gcn_arch_name() == "gfx928":
print("[WARNING] gfx928 architecture is not supported.")
exit(0)
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
......
......@@ -175,9 +175,14 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def get_gcn_arch_name() -> str:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return GPU_ARCH.split(':')[0]
def main(torch_dtype, is_prof=False):
if get_gcn_arch_name() != "gfx938":
print("[WARNING] The architecture is not supported.")
exit(0)
device = torch.device("cuda:0")
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
......
......@@ -252,8 +252,14 @@ def main(torch_dtype, is_prof=False):
# '''
def get_gcn_arch_name() -> str:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return GPU_ARCH.split(':')[0]
if __name__ == "__main__":
if get_gcn_arch_name() != "gfx938":
print("[WARNING] The architecture is not supported.")
exit(0)
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
......
......@@ -232,8 +232,14 @@ def test_flash_mla(p: TestParam) -> Result:
performance_result.is_correct = is_correct
return performance_result
def get_gcn_arch_name() -> str:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return GPU_ARCH.split(':')[0]
def main():
if get_gcn_arch_name() == "gfx928":
print("[WARNING] gfx928 architecture is not supported.")
exit(0)
dtype = torch.bfloat16
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
......
......@@ -51,8 +51,14 @@ def run_test(p: TestParam) -> bool:
else:
return True
def get_gcn_arch_name() -> str:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return GPU_ARCH.split(':')[0]
if __name__ == '__main__':
if get_gcn_arch_name() == "gfx928":
print("[WARNING] gfx928 architecture is not supported.")
exit(0)
device = torch.device("cuda:0")
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device(device)
......
......@@ -141,8 +141,14 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def get_gcn_arch_name() -> str:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return GPU_ARCH.split(':')[0]
def main(torch_dtype, is_prof=False):
if get_gcn_arch_name() == "gfx928":
print("[WARNING] gfx928 architecture is not supported.")
exit(0)
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
......
......@@ -168,7 +168,14 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def get_gcn_arch_name() -> str:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return GPU_ARCH.split(':')[0]
def main(torch_dtype, is_prof=False):
if get_gcn_arch_name() == "gfx928":
print("[WARNING] gfx928 architecture is not supported.")
exit(0)
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment