Commit 892f7274 authored by zhanghj2's avatar zhanghj2
Browse files

支持kv 软fp8 e5m2

parent 11e445c3
......@@ -4,12 +4,13 @@
#include "sparse_decode.h"
#include "dense_decode.h"
#include "dense_decode_qkvfp8.h"
#include "dense_decode_kvfp8.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA";
m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
m.def("dense_decode_fwd", &dense_attn_decode_interface);
m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface);
m.def("dense_decode_fwd_kvfp8", &dense_attn_decode_kvfp8_interface);
m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
}
#pragma once
#include <cutlass/half.h>
#include <cutlass/fast_math.h>
#include "common.h"
#include "params.h"
#include "sm90/decode/dense_kvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_kvfp8_interface(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
std::optional<at::Tensor> &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4)
std::optional<at::Tensor> &num_splits, // batch_size + 1
std::optional<const at::Tensor> &descale_q,
std::optional<const at::Tensor> &descale_k
) {
// Check arch
Arch arch = Arch();
if (!arch.is_sm90a()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture");
}
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16);
TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
auto descale_q_ = descale_q.value();
auto descale_k_ = descale_k.value();
TORCH_CHECK(kcache.dtype() == torch::kFloat8_e5m2, "key must have the same dtype torch::kFloat8_e5m2");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
// Check device
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kcache);
KU_CHECK_DEVICE(seqlens_k);
KU_CHECK_DEVICE(block_table);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_DEVICE(descale_q_);
KU_CHECK_DEVICE(descale_k_);
TORCH_CHECK(descale_q_.stride(-1) == 1);
TORCH_CHECK(descale_k_.stride(-1) == 1);
TORCH_CHECK(descale_q_.dtype() == torch::kFloat);
TORCH_CHECK(descale_k_.dtype() == torch::kFloat);
KU_CHECK_SHAPE(descale_q_, 1);
KU_CHECK_SHAPE(descale_k_, 1);
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576 || head_size_k == 512, "Only head_size_k == 576 or 512 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 16) * 2, 1);
KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
KU_CHECK_SHAPE(seqlens_k, batch_size);
KU_CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(torch::kBFloat16));
at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
if (!tile_scheduler_metadata.has_value()) {
tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
GetDecodeSchedMetaParams get_sched_meta_params = {
batch_size, seqlen_q_ori,
64,
5,
-1, -1,
nullptr, nullptr,
seqlens_k.data_ptr<int>(),
(DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),
num_splits->data_ptr<int>(),
num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
}
// Set the sizes
DenseAttnDecodeParams_fp8 params;
params.b = batch_size;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
params.h_q = num_heads_q;
params.h_k = num_heads_k;
params.num_blocks = num_blocks;
params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = lse.data_ptr<float>();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(1);
params.k_row_stride = kcache.stride(1);
params.o_row_stride = out.stride(2);
params.q_head_stride = q.stride(2);
params.k_head_stride = kcache.stride(2);
params.o_head_stride = out.stride(1);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
params.descale_q_ptr = descale_q_.data_ptr<float>();
params.descale_k_ptr = descale_k_.data_ptr<float>();;
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(out_accum);
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(params);
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
}
CombineParams combine_params = {
batch_size, seqlen_q_ori,
num_heads_q, head_size_v,
params.softmax_lse_ptr,
params.o_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.softmax_lseaccum_ptr,
params.oaccum_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.tile_scheduler_metadata_ptr,
params.num_splits_ptr,
params.num_sm_parts,
nullptr,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
.reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, lse, tile_scheduler_metadata, num_splits};
}
......@@ -183,3 +183,10 @@ inline constexpr bool is_decode_v = std::bool_constant<FWD_MODE == SparseAttnFwd
template<SparseAttnFwdMode FWD_MODE>
using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>;
enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
kInt8 = 3,
};
#pragma once
namespace Config {
static constexpr int BLOCK_SIZE_M = 16;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
}
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
template void run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams_fp8 &params);
}
#include <cutlass/cutlass.h>
#include "utils.h"
#include "params.h"
#include "config.h"
#include "traits.h"
#include "softmax.h"
using namespace cute;
namespace sm90 {
template<typename T>
__device__ void
compute_attn_1rowblock_splitkv_mla_kvfp8(const DenseAttnDecodeParams_fp8 &params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit)
{
constexpr static bool Is_causal = T::Is_causal;
constexpr int kBlockM = T::kBlockM;
constexpr int kBlockN = T::kBlockN;
constexpr int kHeadDim = T::kHeadDim;
constexpr int kHeadDimV = T::kHeadDimV;
const int tidx = threadIdx.x;
const int lane_idx = tidx % 64;
extern __shared__ char shared_memory[];
using SharedMemoryPlan = typename T::SharedMemoryPlan;
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(shared_memory);
using index_t = int64_t;
using Element = typename T::Element;
const index_t row_offset_k = (bidh) * params.k_head_stride;
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<uint8_t *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<uint8_t *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.k_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), typename T::SmemLayoutQ{});
Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutV{});
Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutK{});
Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), typename T::SmemLayoutP{});
Tensor sVt = make_tensor(sV.data(), typename T::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle{});
Tensor sVtNoSwizzle_fp8 = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle_fp8{});
Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), typename T::SmemLayoutRow{});
Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), typename T::SmemLayoutRow{});
typename T::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
typename T::TiledMma_16_16_32 tiled_mma_16_16_32;
auto thr_mma_16_16_32 = tiled_mma_16_16_32.get_thread_slice(tidx);
typename T::TiledMma_O_16_32_16 tiled_mma_o_16_32_16;
auto thr_mma_o_16_32_16 = tiled_mma_o_16_32_16.get_thread_slice(tidx);
typename T::TiledMma_int8 tiled_mma_int8;
auto thr_mma_int8 = tiled_mma_int8.get_thread_slice(tidx);
typename T::TiledMma_O tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
// 过lds读取q, 由于q是4个warp共用的
typename T::GmemTiledCopyQ gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQgQ)));
if (threadIdx.x < 128)
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, false>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.q_seq_per_hk - m_block * kBlockM);
__syncthreads();
auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_16_16_32);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
Tensor tSrK = thr_mma.partition_fragment_B(gK);
Tensor tSrK_int8 = thr_mma_int8.partition_fragment_B(gK);
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16, Element>{}, tiled_mma_o_16_32_16);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle_fp8);
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int n_block = n_block_max - 1;
const auto sk_data = sK.data();
const auto sRow_max_reduce_buffer_data = sRow_max_reduce_buffer.data();
constexpr auto sk_size = size(sK);
const auto sP_data = sP.data();
const auto tSsK_data = tSsK.data();
const auto tOsVt_data = tOsVt.data();
const auto gK_data = gK.data();
constexpr static int BUFFER_SIZE = 1;
constexpr short int wait_cnt = 8;
{
int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k);
sK.data() = n_block % 2 == 1 ? sk_data + sk_size : sk_data;
#pragma unroll
for (int i = 0; i < 8; i++) {
flash::lds_direct_copy_fp8<false, true>(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN);
}
}
constexpr static Fp8KVCacheDataType KV_DTYPE = Fp8KVCacheDataType::kFp8E5M2;
constexpr static bool is_scale_equal_one = true;
const float k_scale = 1.0;
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
for (int masking_step = 0; masking_step < n_masking_steps && n_block >= n_block_min; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s);
Tensor tSrK_int8_copy_view = smem_thr_copy_K.retile_D(tSrK_int8);
{
tSsK.data() = n_block % 2 == 1 ? tSsK_data + sk_size : tSsK_data;
uint32x4_t buffer[BUFFER_SIZE];
flash::buffer_load_copy_fp8<false, true, true, true>(gK, buffer[0], 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN);
#if 0
#else
flash::gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale);
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
flash::gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale);
#endif
// asm volatile("s_barrier\n\t");
}
// if (block0()) {
// printf(" tid = %d %.2f %.2f %.2f %.2f \n",tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3));
// }
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
// if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && !is_scale_equal_one && std::is_same_v<Element, cutlass::bfloat16_t>) {
// acc_s(i) *= k_scale;
// }
if constexpr (!Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
sRow_max_reduce_buffer.data() = n_block % 2 == 1 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data;
if constexpr (n_masking_steps == 1) {
softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
}
else {
const bool is_first_masking_step = masking_step == 0;
is_first_masking_step
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
}
Tensor rP = flash::convert_type<Element>(acc_s);
sP.data() = n_block % 2 == 1 ? sP_data + (-sk_size) : sP_data;
Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP);
__syncthreads();
if (n_block > n_block_min)
{
int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block - 1;
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k);
sK.data() = (n_block - 1) % 2 ? sk_data + sk_size : sk_data;
sRow_max_reduce_buffer.data() = (n_block - 1) % 2 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data;
sP.data() = (n_block - 1) % 2 ? sP_data + (-sk_size) : sP_data;
tSsK.data() = (n_block - 1) % 2 ? tSsK_data + sk_size : tSsK_data;
#pragma unroll
for (int i = 0; i < 8; i++) {
flash::lds_direct_copy_fp8<true, true>(gK, sK, i, params.k_row_stride, seqlen_k - (n_block - 1) * kBlockN);
}
// buffer_load_copy_fp8<false, true, true, true>(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - (n_block - 1) * kBlockN);
// gK.data() = gK.data() + (-offset_k);
}
{
tOsVt.data() = (n_block) % 2 ? tOsVt_data + sk_size : tOsVt_data;
#if 0
#else
flash::gemm1_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_o, tOrP, tOrVt, tOsVt, tiled_mma_o, smem_tiled_copy_V, smem_thr_copy_V, k_scale);
#endif
}
}
for (; n_block >= n_block_min; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s);
Tensor tSrK_int8_copy_view = smem_thr_copy_K.retile_D(tSrK_int8);
{
tSsK.data() = n_block % 2 == 1 ? tSsK_data + sk_size : tSsK_data;
uint32x4_t buffer[BUFFER_SIZE];
flash::buffer_load_copy_fp8<true, true, true, true>(gK, buffer[0], 8, params.k_row_stride, 0, seqlen_k - (n_block - 1) * kBlockN);
#if 0
#else
flash::gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale);
#endif
// asm volatile("s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
flash::gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale);
// asm volatile("s_barrier\n\t");
}
sRow_max_reduce_buffer.data() = n_block % 2 == 1 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data;
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(acc_s);
sP.data() = n_block % 2 == 1 ? sP_data + (-sk_size) : sP_data;
Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP);
__syncthreads();
if (n_block > n_block_min)
{
int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block - 1;
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k);
sK.data() = (n_block - 1) % 2 ? sk_data + sk_size : sk_data;
// sRow_max_reduce_buffer.data() = (n_block - 1) % 2 ? sRow_max_reduce_buffer_data + (-8192) : sRow_max_reduce_buffer_data;
// sP.data() = (n_block - 1) % 2 ? sP_data + (-sk_size) : sP_data;
// tSsK.data() = (n_block - 1) % 2 ? tSsK_data + sk_size : tSsK_data;
#pragma unroll
for (int i = 0; i < 8; i++) {
flash::lds_direct_copy_fp8<true, true>(gK, sK, i, params.k_row_stride, seqlen_k - (n_block - 1) * kBlockN);
}
// buffer_load_copy_fp8<false, true, true, true>(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - (n_block - 1) * kBlockN);
// gK.data() = gK.data() + (-offset_k);
}
{
tOsVt.data() = (n_block) % 2 ? tOsVt_data + sk_size : tOsVt_data;
#if 0
#else
flash::gemm1_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_o, tOrP, tOrVt, tOsVt, tiled_mma_o, smem_tiled_copy_V, smem_thr_copy_V, k_scale);
#endif
// tOsVt.data() = (n_block - 1) % 2 ? tOsVt_data + sk_size : tOsVt_data;
}
}
using ElementAccum = float;
if (NoSplit)
{
using ElementO = Element;
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM;
constexpr bool Split = false;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// }
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
Tensor rO = flash::convert_type<ElementO>(acc_o);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16;
if (row < params.q_seq_per_hk - m_block * kBlockM) {
for (int n = 0; n < size<2>(acc_o); n++) {
col = (tidx % 64 / 16) * 2 + warpid * 64 + n * 256;
for (int ei = 0; ei < 16; ei += 2) {
gOaccum(row, col) = rO(ei, m, n);
gOaccum(row, col + 1) = rO(ei + 1, m, n);
col += 8;
}
}
}
}
}
}
else
{
using ElementO = float;
int split_idx = params.num_splits_ptr[bidb] + n_split_idx;
constexpr bool Split = true;
const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16;
if (row < params.q_seq_per_hk - m_block * kBlockM) {
for (int n = 0; n < size<2>(acc_o); n++) {
col = (tidx % 64 / 16) * 2 + warpid * 64 + n * 256;
for (int ei = 0; ei < 16; ei += 2) {
gOaccum(row, col) = acc_o(ei, m, n);
gOaccum(row, col + 1) = acc_o(ei + 1, m, n);
col += 8;
}
}
}
}
}
}
}
template<typename T>
__global__ void __launch_bounds__(T::NUM_THREADS, 1)
flash_fwd_splitkv_mla_kvfp8_kernel(const DenseAttnDecodeParams_fp8 params) {
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int partition_idx = blockIdx.z;
DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
// if (thread0())
// {
// printf("m_block = %d sched_meta.begin_req_idx = %d \n ", m_block, sched_meta.begin_req_idx);
// }
if (sched_meta.begin_req_idx >= params.b) return;
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
if (batch_idx > sched_meta.begin_req_idx) {
__syncthreads();
}
#if defined(__gfx936__) || defined(__gfx938__)
compute_attn_1rowblock_splitkv_mla_kvfp8<T>(params, batch_idx, bidh, m_block, n_split_idx,
seqlen_k, start_block_idx, end_block_idx, is_no_split
);
#endif
}
}
template<typename InputT>
void run_flash_splitkv_mla_kvfp8_kernel(DenseAttnDecodeParams_fp8 &params) {
FLASH_ASSERT(params.d == Config::HEAD_DIM_K);
FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);
constexpr size_t smem_size = 65536;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
using T = Traits<InputT, Is_causal>;
const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
auto mla_kernel = &flash_fwd_splitkv_mla_kvfp8_kernel<T>;
mla_kernel<<<dim3(num_m_block, params.h_k, params.num_sm_parts), T::NUM_THREADS, smem_size, params.stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();
}
}
#pragma once
#include "params.h"
namespace sm90 {
template<typename InputT>
void run_flash_splitkv_mla_kvfp8_kernel(DenseAttnDecodeParams_fp8 &params);
}
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>
#include "config.h"
using namespace cute;
template<typename InputT_, bool Is_causal_>
struct Traits {
using InputT = InputT_;
static constexpr bool Is_causal = Is_causal_;
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;
static constexpr int NUM_THREADS = 256;
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t>);
static constexpr int kBlockM = BLOCK_SIZE_M;
static constexpr int kBlockN = PAGE_BLOCK_SIZE;
static constexpr int kHeadDim = HEAD_DIM_K;
static constexpr int kHeadDimV = HEAD_DIM_V;
static constexpr int kNWarps = 4;
static constexpr int kSwizzle = 3;
using Element = InputT;
using elem_type = Element;
using ElementAccum = float;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
using SmemLayoutAtomK = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<8 * 32>>{}));
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<8 * 32>>{}));
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<8 * 32>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomV_fp8 = Layout<Shape<Int<kBlockN>, Int<512>>, Stride<_512, _1>>;
using SmemLayoutV_fp8 = decltype(tile_to_shape(
SmemLayoutAtomV_fp8{},
Shape<Int<kBlockN>, Int<512>>{}));
using SmemLayoutVtransposed_fp8 = decltype(
composition(SmemLayoutV_fp8{}, make_layout(Shape<Int<512>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle_fp8 = decltype(get_nonswizzle_portion(SmemLayoutVtransposed_fp8{}));
using SmemLayoutAtomQ = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using MMA_Atom_Arch_16_16_32 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
>;
using TiledMma_16_16_32 = TiledMMA<
MMA_Atom_Arch_16_16_32,
Layout<Shape<_1, Int<4>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x64_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x16x64_F32BF16BF16F32_NT>
>;
using MMA_Atom_Arch_16_32_16 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
using TiledMma_O_16_32_16 = TiledMMA<
MMA_Atom_Arch_16_32_16,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using MMA_Atom_Arch_int8 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x64_F32F16uint8F32_NT>,
MMA_Atom<GFX928_16x16x64_F32BF16int8F32_NT>
>;
using MMA_Atom_Arch_16x64 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x64x16_FP8_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x64x16_FP8_F32BF16BF16F32_NT>
>;
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x64,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using TiledMma_int8 = TiledMMA<
MMA_Atom_Arch_int8,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using GmemLayoutAtomQ = Layout<Shape <_32, _8>,
Stride< _8, _1>>;
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{}));
struct SharedMemoryPlan {
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // Double buffer
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
};
};
};
};
......@@ -9,6 +9,7 @@
#include <cutlass/numeric_types.h>
#include <cute/tensor.hpp>
#include "defines.h"
#include "params.h"
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
......@@ -943,5 +944,561 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int
dst = d;
}
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_fp8(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
if constexpr (Is_load_Q) {
} else {
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 64*64;
// int row = lane / 4;
// int col = lane % 4;
// int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
// 此处待优化,后8行,行号需要交换
int virtual_row = lane / 8;
int virtual_col = lane % 8;
int swizzle_col = virtual_row ^ virtual_col;
int row = lane / 4;
// 8->9 9->8
row = (row >= 8 ) ^ row;
// row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
int col = swizzle_col % 4;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// if (thread(0))
// {
// printf("offset_v = %d %d \n", offset_v, warp_id * bytes_per_warp + k_idx * mma_k * element_size);
// }
#if defined(__gfx936__) || defined(__gfx938__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
__forceinline__ __device__ cutlass::bfloat16_t fp8e4m3_to_bf16(const fp8& input) {
const uint16_t w = (uint16_t)input << 8;
const uint16_t sign = w & UINT16_C(0x8000);
const uint16_t nonsign = w & UINT16_C(0x7FFF);
constexpr uint16_t exp_offset=(0x78 << 7);
uint16_t result = sign | ((nonsign >> 4) + exp_offset);
// if(nonsign == 0x0000) result = 0x0000;
// if (thread0() && nonsign == 0x0000)
// {
// printf(" input = %x result = %x\n", input, result);
// }
return cutlass::bfloat16_t::bitcast(result);
}
__forceinline__ __device__ float fp8e5m2_to_fp32(const fp8& input) {
union uf16{
uint16_t as_bits;
_Float16 as_value;
} ;
union uf32 {
uint32_t as_bits;
float as_value;
};
uf16 u16;
uf32 u32;
u16.as_bits = (uint16_t)input << 8;
u32.as_value = (float)u16.as_value;
// return u32.as_bits>>16;
return u32.as_value;
}
__forceinline__ __device__ cutlass::half_t fp8e5m2_to_fp16(const fp8& input) {
union uf16{
uint16_t as_bits;
__fp16 as_value;
} ;
union uf32 {
uint32_t as_bits;
float as_value;
};
uf16 u16;
// uf32 u32;
// u16.as_bits = (uint16_t)input << 8;
// u32.as_value = (float)u16.as_value;
// return u32.as_bits>>16;
uint16_t output = (uint16_t)(input << 8);
return cutlass::half_t::bitcast(output);
}
template <int N>
CUTE_HOST_DEVICE
void wait_vmcnt() {
asm volatile("s_waitcnt vmcnt(%0) ;\n\t"
"s_barrier; \n\t"
:: "n"(N));
}
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB_int8, Tensor3 &tCrB, Tensor4 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, const float& k_scale )
{
typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef unsigned int __hip_fp8x4_storage_t;
typedef unsigned short int __hip_fp8x2_storage_t;
typedef unsigned char __hip_fp8_storage_t;
union {
__fp16x8_t data_128;
__hip_fp8x4_storage_t fp8_array[4];
} data[8];
__builtin_amdgcn_sched_barrier(0);
wait_vmcnt<8>();
data[0].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 0));
wait_vmcnt<7>();
data[1].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 1));
wait_vmcnt<6>();
data[2].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 2));
wait_vmcnt<5>();
data[3].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 3));
wait_vmcnt<4>();
data[4].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 4));
wait_vmcnt<3>();
data[5].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 5));
wait_vmcnt<2>();
data[6].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 6));
wait_vmcnt<1>();
data[7].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 7));
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < 8; k_idx++)
{
#pragma unroll
for (int j = 0; j < 16; j+=4) {
auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx].fp8_array[j / 4]);
auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx].fp8_array[j / 4])) + 1);
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::bfloat16_t>) {
auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
auto rst0 = fp8e4m3_to_bf16(f1);
auto rst1 = fp8e4m3_to_bf16(f2);
auto rst2 = fp8e4m3_to_bf16(f3);
auto rst3 = fp8e4m3_to_bf16(f4);
tCrB(j, 0, k_idx) = rst0;
tCrB(j + 1, 0, k_idx) = rst1;
tCrB(j + 2, 0, k_idx) = rst2;
tCrB(j + 3, 0, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::bfloat16_t>) {
auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
// if (block0())
// {
// printf("threadIdx.x = %d %.2f %.2f %.2f %.2f \n", threadIdx.x, f1, f2, f3, f4);
// }
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto rst0 = convert_(f1);
auto rst1 = convert_(f2);
auto rst2 = convert_(f3);
auto rst3 = convert_(f4);
tCrB(j, 0, k_idx) = rst0;
tCrB(j + 1, 0, k_idx) = rst1;
tCrB(j + 2, 0, k_idx) = rst2;
tCrB(j + 3, 0, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::half_t>) {
// auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
// auto f2 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
// auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
// auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
// tCrB(j, 0, k_idx) = f1;
// tCrB(j + 1, 0, k_idx) = f2;
// tCrB(j + 2, 0, k_idx) = f3;
// tCrB(j + 3, 0, k_idx) = f4;
__hip_fp8x4_storage_t fp8_data = data[k_idx].fp8_array[j / 4];
union Fp8_data_union{
__hip_fp8x4_storage_t fp8x4;
uint16_t fp16[2];
};
Fp8_data_union first_fp8, last_fp8;
first_fp8.fp8x4 = ((fp8_data & 0xff00ff00));
last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8);
tCrB(j, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]);
tCrB(j + 1, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);;
tCrB(j + 2, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);;
tCrB(j + 3, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;;
}
else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::half_t>) {
auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f2);
auto rst1 = __builtin_amdgcn_cvt_pkrtz(f3, f4);
cutlass::Array<half_t, 2> result0 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst0);
cutlass::Array<half_t, 2> result1 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst1);
tCrB(j, 0, k_idx) = result0[0];
tCrB(j + 1, 0, k_idx) = result0[1];
tCrB(j + 2, 0, k_idx) = result1[0];
tCrB(j + 3, 0, k_idx) = result1[1];
}
}
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
}
template<typename Element, int k_idx, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1,typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_k_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, TiledMma tiled_mma, uint32x4_t& _data, const float& k_scale)
{
typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef unsigned int __hip_fp8x4_storage_t;
typedef unsigned short int __hip_fp8x2_storage_t;
typedef unsigned char __hip_fp8_storage_t;
union {
uint32x4_t data_128;
__hip_fp8x4_storage_t fp8_array[4];
} data;
data.data_128 = _data;
for (int j = 0; j < 16; j+=4) {
auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data.fp8_array[j / 4]);
auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data.fp8_array[j / 4])) + 1);
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::bfloat16_t>) {
auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
auto rst0 = fp8e4m3_to_bf16(f1);
auto rst1 = fp8e4m3_to_bf16(f2);
auto rst2 = fp8e4m3_to_bf16(f3);
auto rst3 = fp8e4m3_to_bf16(f4);
tCrB(j, 0, k_idx) = rst0;
tCrB(j + 1, 0, k_idx) = rst1;
tCrB(j + 2, 0, k_idx) = rst2;
tCrB(j + 3, 0, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::bfloat16_t>) {
auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
// if (thread0()) {
// printf(" static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8) = %x f1 = %.2f\n", static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), f1);
// }
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto rst0 = convert_(f1);
auto rst1 = convert_(f2);
auto rst2 = convert_(f3);
auto rst3 = convert_(f4);
tCrB(j, 0, k_idx) = rst0;
tCrB(j + 1, 0, k_idx) = rst1;
tCrB(j + 2, 0, k_idx) = rst2;
tCrB(j + 3, 0, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::half_t>) {
__hip_fp8x4_storage_t fp8_data = data.fp8_array[j / 4];
union Fp8_data_union{
__hip_fp8x4_storage_t fp8x4;
uint16_t fp16[2];
} ;
Fp8_data_union first_fp8, last_fp8;
first_fp8.fp8x4 = ((fp8_data & 0xff00ff00));
last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8);
tCrB(j, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]);
tCrB(j + 1, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);;
tCrB(j + 2, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);;
tCrB(j + 3, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;;
}
else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::half_t>) {
auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f2);
auto rst1 = __builtin_amdgcn_cvt_pkrtz(f3, f4);
cutlass::Array<half_t, 2> result0 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst0);
cutlass::Array<half_t, 2> result1 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst1);
tCrB(j, 0, k_idx) = result0[0];
tCrB(j + 1, 0, k_idx) = result0[1];
tCrB(j + 2, 0, k_idx) = result1[0];
tCrB(j + 3, 0, k_idx) = result1[1];
}
}
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_fp8(
Tensor<SrcEngine, SrcLayout> const& src,
uint32x4_t & dst,
int k_idx_, const int row_stride,
int offset_k,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 16;
if constexpr (mma_layout)
{
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint32x4_t*>(&res);
}
}
}
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor3 &tCrB, Tensor4 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, const float& k_scale ) {
typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef unsigned int __hip_fp8x4_storage_t;
typedef unsigned short int __hip_fp8x2_storage_t;
typedef unsigned char __hip_fp8_storage_t;
auto lds = reinterpret_cast<__fp16 *>(&tCsB(0, 0, 0));
auto layout = tCsB.layout();
union {
__fp16x8_t data_128;
__hip_fp8x4_storage_t fp8_array[4];
} data[8];
constexpr short offset0 = layout(0, 0, 0) * 2;
data[0].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset0);
constexpr short offset1 = layout(0, 1, 0) * 2;
data[1].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset1);
constexpr short offset2 = layout(0, 0, 1) * 2;
data[2].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset2);
constexpr short offset3 = layout(0, 1, 1) * 2;
data[3].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset3);
constexpr short offset4 = layout(0, 0, 2) * 2;
data[4].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset4);
constexpr short offset5 = layout(0, 1, 2) * 2;
data[5].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset5);
constexpr short offset6 = layout(0, 0, 3) * 2;
data[6].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset6);
constexpr short offset7 = layout(0, 1, 3) * 2;
data[7].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset7);
#pragma unroll
for (int k_idx = 0; k_idx < 4; k_idx++) {
#pragma unroll
for (int i = 0; i < 2; i++) {
#pragma unroll
for (int j = 0; j < 16; j+=4) {
auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx * 2 + i].fp8_array[j / 4]);
auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx * 2 + i].fp8_array[j / 4])) + 1);
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::bfloat16_t>) {
// cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
auto rst0 = fp8e4m3_to_bf16(f1);
auto rst1 = fp8e4m3_to_bf16(f2);
auto rst2 = fp8e4m3_to_bf16(f3);
auto rst3 = fp8e4m3_to_bf16(f4);
tCrB(j, i, k_idx) = rst0;
tCrB(j + 1, i, k_idx) = rst1;
tCrB(j + 2, i, k_idx) = rst2;
tCrB(j + 3, i, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::bfloat16_t>) {
auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto rst0 = convert_(f1);
auto rst1 = convert_(f2);
auto rst2 = convert_(f3);
auto rst3 = convert_(f4);
tCrB(j, i, k_idx) = rst0;
tCrB(j + 1, i, k_idx) = rst1;
tCrB(j + 2, i, k_idx) = rst2;
tCrB(j + 3, i, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::half_t>) {
// auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
// auto f2 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
// auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
// auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
// tCrB(j, i, k_idx) = f1;
// tCrB(j + 1, i, k_idx) = f2;
// tCrB(j + 2, i, k_idx) = f3;
// tCrB(j + 3, i, k_idx) = f4;
__hip_fp8x4_storage_t fp8_data = data[k_idx * 2 + i].fp8_array[j / 4];
union Fp8_data_union{
__hip_fp8x4_storage_t fp8x4;
uint16_t fp16[2];
} ;
Fp8_data_union first_fp8, last_fp8;
first_fp8.fp8x4 = ((fp8_data & 0xff00ff00));
last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8);
tCrB(j, i, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]);
tCrB(j + 1, i, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);;
tCrB(j + 2, i, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);;
tCrB(j + 3, i, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::half_t>){
auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f3);
auto rst1 = __builtin_amdgcn_cvt_pkrtz(f2, f4);
cutlass::Array<half_t, 2> result0 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst0);
cutlass::Array<half_t, 2> result1 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst1);
tCrB(j, i, k_idx) = result0[0];
tCrB(j + 1, i, k_idx) = result1[0];
tCrB(j + 2, i, k_idx) = result0[1];
tCrB(j + 3, i, k_idx) = result1[1];
}
}
}
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
}
}
\ No newline at end of file
......@@ -4,12 +4,14 @@ from flash_mla.flash_mla_interface import (
get_mla_metadata,
flash_mla_with_kvcache,
flash_mla_sparse_fwd,
flash_mla_with_kvcache_qkvfp8
flash_mla_with_kvcache_qkvfp8,
flash_mla_with_kvcache_kvfp8
)
__all__ = [
"get_mla_metadata",
"flash_mla_with_kvcache",
"flash_mla_sparse_fwd",
"flash_mla_with_kvcache_qkvfp8"
"flash_mla_with_kvcache_qkvfp8",
"flash_mla_with_kvcache_kvfp8"
]
......@@ -228,8 +228,6 @@ def flash_mla_with_kvcache_qkvfp8(
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
......@@ -289,4 +287,82 @@ def flash_mla_with_kvcache_qkvfp8(
)
sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
sched_meta.num_splits = new_num_splits
return (out, lse)
def flash_mla_with_kvcache_kvfp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: Optional[torch.Tensor],
cache_seqlens: Optional[torch.Tensor],
head_dim_v: int,
tile_scheduler_metadata: FlashMLASchedMeta,
num_splits: None = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
sched_meta = tile_scheduler_metadata
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits must be None"
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if not sched_meta.have_initialized:
# Initialize the tile scheduler metadata during the first invocation.
sched_meta.have_initialized = True
sched_meta.config = FlashMLASchedMeta.Config(
q.shape[0],
q.shape[1],
q.shape[2],
k_cache.shape[1],
k_cache.shape[2],
causal,
False,
0,
0,
0
)
else:
# Check whether the input arguments are consistent with sched_meta
helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assert sched_meta.config is not None
assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# Dense attention
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
q, k_cache, head_dim_v,
cache_seqlens, block_table,
softmax_scale, causal,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
descale_q, descale_k
)
sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
sched_meta.num_splits = new_num_splits
return (out, lse)
\ No newline at end of file
......@@ -61,6 +61,9 @@ ext_modules.append(
## sm90 dense qkvfp8 decode
"csrc/sm90/decode/dense_qkvfp8/instantiations/fp8e4m3.cu",
## sm90 dense kvfp8 decode
"csrc/sm90/decode/dense_kvfp8/instantiations/kvfp8.cu",
# # sm90 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu",
......@@ -97,7 +100,7 @@ ext_modules.append(
Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me
Path(this_dir) / "csrc" / "sm90",
Path(this_dir) / "csrc" / "cutlass" / "include"
Path(this_dir) / "csrc" / "cutlass" / "include",
],
)
)
......
import argparse
import math
import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache_kvfp8, get_mla_metadata
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
query = query.float()
key = key.float() * k_scale
value = value.float() * k_scale
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
# tmp = query @ key.transpose(-2, -1)
# print("tmp ", tmp.shape, tmp[0, 0, :16])
# print("tmp ", tmp.shape, tmp[0, 0, 16:32])
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
torch_dtype = x.dtype
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
@torch.inference_mode()
def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")
q = torch.randn(b, s_q, h_q, d)
# q = torch.ones(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d))
# blocked_k = (torch.zeros(block_table.numel(), block_size, h_kv, d))
# blocked_k[:, 1:, :, :] = 0
# blocked_k[:, :, :, 1:] = 0
# blocked_k[0, 0:16, 0, 0] = 0
# blocked_k[0, 32:, 0, 0] = 0
# blocked_k[0, 0, 0, 1] = 2
# blocked_k[0, 0, 0, 2] = 3
# blocked_k[0, 0, 0, 3] = 4
# print(" blocked_k ", blocked_k[0, 0, 0, :])
blocked_k = blocked_k.to(torch.float8_e5m2)
# blocked_k = (torch.ones(block_table.numel(), block_size, h_kv, d)).to(torch.float8_e5m2)
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5
# blocked_k[0, 3, 0, 8] = 4
# for i in range(64):
# for j in range(64):
# blocked_k[0, i, 0, j] = j
# blocked_k[0, i, 0, j] = (i * 50 + j) % 128
# print("blocked_k ", blocked_k[0, 0, 0, 0:10])
# for i in range(b):
# blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
# -128
# )
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata()
# print("q:", q.shape, q.dtype, q)
# print("cache_seqlens:", cache_seqlens.shape, cache_seqlens)
# print("block_table:", block_table.shape, block_table)
# print("blocked_k:", blocked_k.shape, blocked_k[0])
# print("blocked_v:", blocked_v.shape)
# torch.set_printoptions(precision=4, profile="full", sci_mode=False)
# print("tile_scheduler_metadata:", tile_scheduler_metadata.shape, tile_scheduler_metadata)
# torch.set_printoptions(precision=4, profile="default", sci_mode=False)
# print("num_splits:", num_splits.shape, num_splits)
# k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0")
k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
def flash_mla():
return flash_mla_with_kvcache_kvfp8(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q = descale_q,
descale_k = descale_k,
)
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
k_scale = k_scale
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
# print("out_flash ", out_flash[0, 0, 0, 0:14])
# print("out_torch ", out_torch[0, 0, 0, 0:14])
# print("lse_flash ", lse_flash[0, 0, 0:10])
# print("lse_torch ", lse_torch[0, 0, 0:10])
print("out max_diff ", (out_flash - out_torch).abs().max())
print("lse max_diff ", (lse_flash - lse_torch).abs().max())
# print(" out ", torch.nonzero((out_flash - out_torch).abs() > 0.1))
# print(" out_torch", out_torch)
cal_diff(lse_flash, lse_torch, "lse")
cal_diff(out_flash, out_torch, "out")
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = ( b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(q.dtype).bits // 8
) + total_seqlens * h_kv * d
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype, is_prof=False):
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
'''
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# b, s_q, s, h_q, h_kv, d, dv, causal, varlen'''
# test_flash_mla( 1, 1, 64, 16, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 1000, 1, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 4096, 8, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8(32, 1, 4096, 16, 1, 576, 512, False, False, is_prof=is_prof)
# '''
h_kv = 1
d, dv = 576, 512
causal = True
# for b in [1, 32]:
# for s in [200, 1002, 2002, 1024, 2000, 4000, 32768, 65536]:
# for h_q in [4, 16, 32, 64]:
# for s_q in [1, 2]: # MTP = 1, 2
# for varlen in [True]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# for b in [32]:
# for s in [16384, 32768, 65536*2]:
# for h_q in [16]:
# for s_q in [1]: # MTP = 1, 2
# for varlen in [False]:
# # for varlen in [True]:
# test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# test_flash_mla_fp8_e4m3(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# '''
for b in [3, 6, 9, 12, 15, 18, 21, 24]:
for s in [111, 112, 123, 1234, 432, 4325, 4000, 8192, 11111]:
for h_q in [16]:
for s_q in [1, 2, 3]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,True)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# for b in [1]:
# for s in [64]:
# for h_q in [16]:
# for s_q in [1]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16"],
default="bf16",
help="Data type to use for testing (bf16 or fp16)",
)
parser.add_argument('--prof', default=False, action='store_true', help='prof or not')
args = parser.parse_args()
torch_dtype = torch.bfloat16
if args.dtype == "fp16":
torch_dtype = torch.float16
main(torch_dtype, args.prof)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment