Commit c353b35b authored by zhanghj2's avatar zhanghj2
Browse files

恢复支持旧接口

parent c566af36
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "dense_decode.h" #include "dense_decode.h"
#include "dense_decode_qkvfp8.h" #include "dense_decode_qkvfp8.h"
#include "dense_decode_kvfp8.h" #include "dense_decode_kvfp8.h"
#include "../extension/flash_api.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA"; m.doc() = "FlashMLA";
m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
...@@ -13,4 +13,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -13,4 +13,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface); m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface);
m.def("dense_decode_fwd_kvfp8", &dense_attn_decode_kvfp8_interface); m.def("dense_decode_fwd_kvfp8", &dense_attn_decode_kvfp8_interface);
m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
m.def("get_mla_decoding_metadata_dense_fp8", &get_mla_decoding_metadata_dense_fp8);
m.def("fwd_kvcache_quantization_mla", &mha_fwd_kvcache_quantization_mla);
m.def("fwd_kvcache_quantization_q_nope_pe_mla", &mha_fwd_kvcache_quantization_q_nope_pe_mla);
m.def("fwd_kvcache_mla_nope_pe", &mha_fwd_kvcache_mla_nope_pe);
m.def("fwd_kvcache_mla_fp8", &mha_fwd_kvcache_mla_fp8);
m.def("fwd_kvcache_mla_fp8_with_cat", &mha_fwd_kvcache_mla_fp8_with_cat);
} }
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
// #include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/fast_math.h>
#include <cstdlib>
#include "flash_mla.h"
#include "static_switch.h"
// Argument-validation helpers built on TORCH_CHECK. Each one stringifies its
// argument so the raised error names the offending tensor.
// CHECK_DEVICE: fails unless x.is_cuda() is true.
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
// CHECK_SHAPE: fails unless x.sizes() equals exactly the listed extents (rank and every dimension).
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// CHECK_CONTIGUOUS: fails unless x.is_contiguous() is true.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Debug switch read once at static-initialisation time from get_env_("FLASH_MLA_PRINT_PARAM").
// When true, the entry points in this file print their input sizes to stderr before launching.
// NOTE(review): get_env_ is defined elsewhere; which values of the variable count as "true" is not visible here.
static const bool print_param = get_env_("FLASH_MLA_PRINT_PARAM");
/**
 * Runs `cmd` through popen() in read mode and returns everything the command
 * wrote to its standard output.
 *
 * At most one trailing '\n' is removed from the captured text; any other
 * newlines are kept as-is.
 *
 * @param cmd  Command line handed to popen().
 * @return     The captured output, or the literal string "popen failed" when
 *             the pipe could not be opened (kept unchanged for existing callers).
 */
static std::string execCommand(const char* cmd) {
    // Owns the pipe so it is closed on every exit path, including the case
    // where growing `captured` throws (previously the FILE* leaked there).
    struct PipeGuard {
        FILE* handle;
        explicit PipeGuard(FILE* h) : handle(h) {}
        PipeGuard(const PipeGuard&) = delete;
        PipeGuard& operator=(const PipeGuard&) = delete;
        ~PipeGuard() {
            if (handle != nullptr) {
                pclose(handle);  // also waits for the child process to finish
            }
        }
    };

    PipeGuard pipe(popen(cmd, "r"));
    if (pipe.handle == nullptr) {
        return "popen failed";
    }

    std::string captured;
    char chunk[256];
    while (fgets(chunk, sizeof(chunk), pipe.handle) != nullptr) {
        captured += chunk;
    }

    if (!captured.empty() && captured.back() == '\n') {
        captured.pop_back();
    }
    return captured;
}
/**
 * MLA attention forward over a paged, fp8-quantized KV cache with a single
 * (non-split) query tensor.
 *
 * Validates the inputs, folds the query-head groups into the query-sequence
 * axis, fills a Flash_fwd_mla_params and dispatches to
 * run_mha_fwd_splitkv_mla<..., 576> for bf16 (or fp16 unless
 * FLASH_MLA_DISABLE_FP16 is defined) queries.
 *
 * @param q                        batch_size x seqlen_q x num_heads x head_size. NOTE: the
 *                                 reference is rebound to the regrouped view
 *                                 (batch_size x seqlen_q*ngroups x num_heads_k x head_size).
 * @param kcache                   num_blocks x page_block_size x num_heads_k x head_size.
 * @param vcache_                  Optional value cache; kcache is used when absent.
 * @param head_size_v              Value head size; must be a multiple of 32.
 * @param seqlens_k                int32, contiguous, shape (batch_size).
 * @param block_table              int32, batch_size x max_num_blocks_per_seq.
 * @param softmax_scale            Scale stored in params.scale_softmax.
 * @param is_causal                Forced to false when seqlen_q == 1.
 * @param tile_scheduler_metadata  int32, contiguous, num_sm_parts x TileSchedulerMetaDataSize.
 * @param num_splits               int32, contiguous (batch_size + 1 per the signature comment).
 * @param k_scale                  float32 CUDA tensor; its data pointer is passed as params.k_scale_ptr.
 * @param kv_cache_dtype           Must be "fp8_e4m3" or "fp8_e5m2"; anything else raises.
 * @return {out, softmax_lse}: out is batch_size x seqlen_q x num_heads x head_size_v in q's
 *         dtype, softmax_lse is batch_size x num_heads x seqlen_q in float32.
 * @throws c10::Error (via TORCH_CHECK) on any failed validation.
 */
std::vector<at::Tensor>
mha_fwd_kvcache_quantization_mla(
    at::Tensor &q,                                // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor &kcache,                     // num_blocks x page_block_size x num_heads_k x head_size
    std::optional<const at::Tensor> &vcache_,     // num_blocks x page_block_size x num_heads_k x head_size_v
    const int head_size_v,
    const at::Tensor &seqlens_k,                  // batch_size
    const at::Tensor &block_table,                // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,    // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits,                 // batch_size + 1
    const at::Tensor &k_scale,
    const std::string &kv_cache_dtype
) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_gfx936 = dprops->major == 9 && dprops->minor == 3;

    at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;

    auto q_dtype = q.dtype();
    if (kv_cache_dtype == "fp8_e4m3" || kv_cache_dtype == "fp8_e5m2")
    {
        // Quantized path: the cache holds fp8 data, so it cannot share the query dtype.
        TORCH_CHECK(kcache.dtype() != q_dtype, "kv cache is fp8-quantized, so query and key cache must not have the same dtype");
        CHECK_DEVICE(k_scale);
        TORCH_CHECK(k_scale.dtype() == torch::kFloat32, "k_scale must have dtype float32");
        TORCH_CHECK(is_gfx936, "fp8_e4m3 and fp8_e5m2 Attention Forward Kernel (mha_fwd_kvcache_quantization_mla) is only supported on gfx936 architectures");
    }
    else
    {
        TORCH_CHECK(false, "Unsupported kv cache dtype");
    }

    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");

    // sizes[3] below is an unchecked index, so pin the rank first.
    TORCH_CHECK(q.dim() == 4, "q must have shape (batch_size, seqlen_q, num_heads, head_size)");
    const auto sizes = q.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_ori = sizes[2];
    const int head_size = sizes[3];
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    // Guard the modulo/division below against an empty head dimension.
    TORCH_CHECK(num_heads_k > 0, "kcache must have at least one head");
    TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (seqlen_q_ori == 1) { is_causal = false; }

    // Fold the query heads that share one KV head into the sequence axis, so the
    // kernel sees num_heads_k heads and seqlen_q_ori * ngroups query rows.
    const int ngroups = num_heads_ori / num_heads_k;
    const int seqlen_q = seqlen_q_ori * ngroups;
    const int num_heads = num_heads_k;
    q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
            .reshape({batch_size, seqlen_q, num_heads, head_size});

    int head_size_k = head_size;
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);

    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    CHECK_DEVICE(seqlens_k);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_SHAPE(seqlens_k, batch_size);

    // Make q's device current for the allocations and stream lookup below.
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_mla_params params = {};
    // Set the sizes.
    params.b = batch_size;
    params.seqlen_q = seqlen_q;
    params.cu_seqlens_k = seqlens_k.data_ptr<int>();
    params.h = num_heads;
    params.h_h_k_ratio = num_heads / num_heads_k;  // num_heads == num_heads_k after regrouping
    params.ngroups = ngroups;
    params.is_causal = is_causal;
    params.d = head_size;
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = kcache.data_ptr();
    params.v_ptr = vcache.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    // All stride are in elements, not bytes.
    params.q_batch_stride = q.stride(0);
    params.k_batch_stride = kcache.stride(0);
    params.v_batch_stride = vcache.stride(0);
    params.o_batch_stride = out.stride(0);
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = kcache.stride(-3);
    params.v_row_stride = vcache.stride(-3);
    params.o_row_stride = out.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = kcache.stride(-2);
    params.v_head_stride = vcache.stride(-2);
    params.o_head_stride = out.stride(-2);

    params.block_table = block_table.data_ptr<int>();
    params.block_table_batch_stride = block_table.stride(0);
    params.page_block_size = page_block_size;
    params.k_scale_ptr = k_scale.data_ptr();

    TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
    TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
    CHECK_DEVICE(tile_scheduler_metadata);
    CHECK_CONTIGUOUS(tile_scheduler_metadata);
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    params.num_sm_parts = tile_scheduler_metadata.size(0);

    TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
    CHECK_DEVICE(num_splits);
    CHECK_CONTIGUOUS(num_splits);
    params.num_splits_ptr = num_splits.data_ptr<int>();

    // Float32 scratch buffers with batch_size + num_sm_parts leading slots.
    at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // Only the 576-wide head instantiation is dispatched below.
    TORCH_CHECK(head_size == 576);

    if (print_param)
    {
        fprintf(stderr, "[flashmla] [mha_fwd_kvcache_quantization_mla] q_dtype = %s input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f kv_cache_dtype = %s\n", (q_dtype == torch::kBFloat16?"bf16":"fp16"),
            batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq,
            num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale, kv_cache_dtype.c_str());
    }

    if (q_dtype == torch::kBFloat16) {
        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, kv_cache_dtype, stream);
    }
#ifndef FLASH_MLA_DISABLE_FP16
    else if (q_dtype == torch::kHalf) {
        run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, kv_cache_dtype, stream);
    }
#endif
    else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }

    // Undo the head regrouping so the results use the caller's original head layout.
    out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
            .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
    softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
            .reshape({batch_size, num_heads_ori, seqlen_q_ori});

    return {out, softmax_lse};
}
// static inline int int64_stride_to_int(int64_t orig_stride) {
// if (orig_stride > std::numeric_limits<int>::max()) {
// TORCH_CHECK(false, "[Sparse TopK Attention] Stride exceeds int32 limit: ", orig_stride);
// }
// return static_cast<int>(orig_stride);
// }
// Scheduling constants returned by get_attn_impl_meta().
// Kept as a plain aggregate: get_attn_impl_meta() brace-initialises it positionally,
// so the member order below must not change.
struct DecodingAttnImplMeta {
// Number of SM partitions; get_attn_impl_meta() derives it from the SM count and clamps it to >= 1.
int num_sm_parts;
// get_attn_impl_meta() always sets this to 5. NOTE(review): presumably a per-partition
// overhead expressed in KV blocks for the tile scheduler — confirm against the kernel.
int fixed_overhead_num_blocks;
// get_attn_impl_meta() always sets this to 64. NOTE(review): presumably the KV tile length — confirm.
int k_block_size;
};
/**
 * Picks the scheduling constants for a decoding-attention launch.
 *
 * Only two combinations are handled: sparse attention with an fp8 KV cache,
 * and dense attention with a non-fp8 KV cache. The other two raise.
 *
 * @param sm_count                 Number of SMs available for the launch.
 * @param num_q_tokens_per_head_k  Query rows handled per KV head.
 * @param h_k                      Number of KV heads.
 * @param h_q_                     Number of query heads; required on the sparse fp8 path,
 *                                 optional on the dense path where it only selects 64 vs 16.
 * @param is_fp8_kvcache           Whether the KV cache is fp8.
 * @param is_sparse_attn           Whether sparse attention is requested.
 * @return {num_sm_parts (>= 1), 5, 64}.
 * @throws c10::Error (via TORCH_CHECK) for unsupported combinations or failed preconditions.
 */
DecodingAttnImplMeta get_attn_impl_meta(
    int sm_count,
    int num_q_tokens_per_head_k,
    int h_k,
    std::optional<int> h_q_,
    bool is_fp8_kvcache,
    bool is_sparse_attn
) {
    // Both supported paths report the same overhead and KV block size.
    constexpr int kFixedOverheadNumBlocks = 5;
    constexpr int kKBlockSize = 64;

    if (is_sparse_attn) {
        // Sparse attention is only available together with an fp8 KV cache.
        TORCH_CHECK(is_fp8_kvcache, "Sparse BF16 MLA is not supported on SM90");

        // FP8 + Sparse MLA
        TORCH_CHECK(h_q_.has_value());
        int h_q = h_q_.value();
        TORCH_CHECK(h_q % h_k == 0);
        const int s_q = num_q_tokens_per_head_k * h_k / h_q;
        const int head_tiles = cutlass::ceil_div(h_q/h_k, 16);
        const int parts = (sm_count * 2) / h_k / (head_tiles * s_q);
        return {std::max(parts, 1), kFixedOverheadNumBlocks, kKBlockSize};
    }

    // Dense attention is only available with a non-fp8 KV cache.
    TORCH_CHECK(!is_fp8_kvcache, "Dense FP8 MLA is not supported on SM90");

    // Dense BF16 MLA: 64 query rows per tile when at least 64 query heads are reported, else 16.
    const int rows_per_tile = (h_q_.has_value() && h_q_.value() >= 64) ? 64 : 16;
    const int parts = sm_count / h_k / cutlass::ceil_div(num_q_tokens_per_head_k, rows_per_tile);
    return {std::max(parts, 1), kFixedOverheadNumBlocks, kKBlockSize};
}
/**
 * Builds the tile-scheduler metadata and split table for a dense MLA decode launch.
 *
 * The query tile height (block_size_m) is 16 by default, 32 when h_q > 16 and
 * 64 when h_q >= 64. The SM count of seqlens_k's device is doubled for the
 * 16-row tile, then divided across KV heads and query tiles to obtain
 * num_sm_parts, which is clamped to at least 1.
 *
 * @param seqlens_k             int32, contiguous CUDA tensor; size(0) is taken as the batch size.
 * @param num_heads_per_head_k  Query rows per KV head; must be positive.
 * @param num_heads_k           Number of KV heads; must be positive.
 * @param h_q                   Optional query-head count used only to choose block_size_m.
 * @return {tile_scheduler_metadata (num_sm_parts x TileSchedulerMetaDataSize),
 *          num_splits (batch_size + 1)}, both with seqlens_k's options.
 * @throws c10::Error (via TORCH_CHECK) on any failed validation.
 */
std::vector<at::Tensor>
get_mla_decoding_metadata_dense_fp8(
    at::Tensor &seqlens_k,
    const int num_heads_per_head_k,
    const int num_heads_k,
    const std::optional<int> h_q
) {
    // This should match the logic in the MLA kernel.
    int block_size_m = 16;
    static constexpr int block_size_n = 64;
    if (h_q.has_value()) {
        if (h_q.value() >= 64) {
            block_size_m = 64;
        } else if (h_q.value() > 16) {
            block_size_m = 32;
        }
    }
    static constexpr int fixed_overhead_num_blocks = 5;

    CHECK_DEVICE(seqlens_k);
    TORCH_CHECK(seqlens_k.is_contiguous());
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
    // Both values are divisors below; reject zero/negative instead of dividing by zero.
    TORCH_CHECK(num_heads_k > 0, "num_heads_k must be positive");
    TORCH_CHECK(num_heads_per_head_k > 0, "num_heads_per_head_k must be positive");

    // Switch to the tensor's device before reading device properties, so the SM
    // count belongs to the GPU that will actually run the kernel.
    at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};

    int batch_size = seqlens_k.size(0);
    int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
    auto options = seqlens_k.options();

    auto dprops = at::cuda::getCurrentDeviceProperties();
    int sm_count = dprops->multiProcessorCount*(block_size_m == 16 ? 2 : 1);
    // Clamp to one partition, matching get_attn_impl_meta(); zero partitions would
    // produce an empty metadata tensor.
    int num_sm_parts = std::max(sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m), 1);
    if (print_param)
    {
        fprintf(stderr, "[flashmla] [get_mla_decoding_metadata_dense_fp8] block_size_m=%d sm_count=%d num_sm_parts=%d\n",
            block_size_m, sm_count, num_sm_parts);
    }

    auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
    auto num_splits = torch::empty({batch_size + 1}, options);
    int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    int *num_splits_ptr = num_splits.data_ptr<int>();

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    Mla_metadata_params params = {};
    params.seqlens_k_ptr = seqlens_k_ptr;
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
    params.num_splits_ptr = num_splits_ptr;
    params.batch_size = batch_size;
    params.block_size_n = block_size_n;
    params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
    params.num_sm_parts = num_sm_parts;
    get_mla_metadata_func(params, stream);

    return {tile_scheduler_metadata, num_splits};
}
/**
 * MLA attention forward over a paged KV cache where the query is supplied as
 * two tensors (q_nope and q_pe) instead of one concatenated tensor.
 *
 * The KV cache must have the same dtype as q_nope, the query length must be 1,
 * and head_size_nope + head_size_pe must equal 576. Dispatches to
 * run_mha_fwd_splitkv_mla<..., 576>(params, "auto", stream, true).
 * NOTE(review): the trailing `true` presumably selects the split nope/pe query
 * path in the launcher — confirm against its declaration.
 *
 * @param q_nope                   batch_size x seqlen_q x num_heads x head_size_nope. NOTE: rebound
 *                                 to the regrouped view.
 * @param q_pe                     batch_size x seqlen_q x num_heads x head_size_pe. NOTE: rebound
 *                                 to the regrouped view.
 * @param kcache                   num_blocks x page_block_size x num_heads_k x (head_size_nope + head_size_pe).
 * @param vcache_                  Optional value cache; kcache is used when absent.
 * @param head_size_v              Value head size; must be a multiple of 32.
 * @param seqlens_k                int32, contiguous, shape (batch_size).
 * @param block_table              int32, batch_size x max_num_blocks_per_seq.
 * @param softmax_scale            Scale stored in params.scale_softmax.
 * @param is_causal                Forced to false when seqlen_q == 1 (the only accepted length).
 * @param tile_scheduler_metadata  int32, contiguous, num_sm_parts x TileSchedulerMetaDataSize.
 * @param num_splits               int32, contiguous (batch_size + 1 per the signature comment).
 * @return {out, softmax_lse}: out is batch_size x seqlen_q x num_heads x head_size_v in q_nope's
 *         dtype, softmax_lse is batch_size x num_heads x seqlen_q in float32.
 * @throws c10::Error (via TORCH_CHECK) on any failed validation.
 */
std::vector<at::Tensor>
mha_fwd_kvcache_mla_nope_pe(
    at::Tensor &q_nope,                           // batch_size x seqlen_q x num_heads x head_size
    at::Tensor &q_pe,                             // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor &kcache,                     // num_blocks x page_block_size x num_heads_k x head_size
    std::optional<const at::Tensor> &vcache_,     // num_blocks x page_block_size x num_heads_k x head_size_v
    const int head_size_v,
    const at::Tensor &seqlens_k,                  // batch_size
    const at::Tensor &block_table,                // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,    // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits                  // batch_size + 1
) {
    at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;

    auto q_dtype = q_nope.dtype();
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");

    CHECK_DEVICE(q_nope); CHECK_DEVICE(q_pe); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
    TORCH_CHECK(q_nope.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(q_pe.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");

    // sizes[3] below is an unchecked index, so pin the rank first.
    TORCH_CHECK(q_nope.dim() == 4, "q_nope must have shape (batch_size, seqlen_q, num_heads, head_size_nope)");
    const auto sizes = q_nope.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_ori = sizes[2];
    const int head_size_nope = sizes[3];
    const int head_size_pe = q_pe.size(3);
    const int head_size = head_size_nope + head_size_pe;
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    // Guard the modulo/division below against an empty head dimension.
    TORCH_CHECK(num_heads_k > 0, "kcache must have at least one head");
    TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (seqlen_q_ori == 1) { is_causal = false; }
    TORCH_CHECK(seqlen_q_ori == 1, "mha_fwd_kvcache_mla_nope_pe only support seqlen_q_ori=1");

    // Fold the query heads that share one KV head into the sequence axis, so the
    // kernel sees num_heads_k heads and seqlen_q_ori * ngroups query rows.
    const int ngroups = num_heads_ori / num_heads_k;
    const int seqlen_q = seqlen_q_ori * ngroups;
    const int num_heads = num_heads_k;
    q_nope = q_nope.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_nope}).transpose(2, 3)
            .reshape({batch_size, seqlen_q, num_heads, head_size_nope});
    q_pe = q_pe.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_pe}).transpose(2, 3)
            .reshape({batch_size, seqlen_q, num_heads, head_size_pe});

    int head_size_k = head_size;
    CHECK_SHAPE(q_nope, batch_size, seqlen_q, num_heads, head_size_nope);
    CHECK_SHAPE(q_pe, batch_size, seqlen_q, num_heads, head_size_pe);
    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);

    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    CHECK_DEVICE(seqlens_k);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_SHAPE(seqlens_k, batch_size);

    // Make q_nope's device current so the stream fetched below belongs to the
    // same GPU as the tensors (this guard was previously commented out).
    at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};

    auto opts = q_nope.options();
    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_mla_params params = {};
    // Set the sizes.
    params.b = batch_size;
    params.seqlen_q = seqlen_q;
    params.cu_seqlens_k = seqlens_k.data_ptr<int>();
    params.h = num_heads;
    params.h_h_k_ratio = num_heads / num_heads_k;  // num_heads == num_heads_k after regrouping
    params.ngroups = ngroups;
    params.is_causal = is_causal;
    params.d = head_size;
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
    // Set the pointers and strides.
    params.q_nope_ptr = q_nope.data_ptr();
    params.q_pe_ptr = q_pe.data_ptr();
    params.k_ptr = kcache.data_ptr();
    params.v_ptr = vcache.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    // All stride are in elements, not bytes.
    params.q_nope_batch_stride = q_nope.stride(0);
    params.q_pe_batch_stride = q_pe.stride(0);
    params.k_batch_stride = kcache.stride(0);
    params.v_batch_stride = vcache.stride(0);
    params.o_batch_stride = out.stride(0);
    params.q_nope_row_stride = q_nope.stride(-3);
    params.q_pe_row_stride = q_pe.stride(-3);
    params.k_row_stride = kcache.stride(-3);
    params.v_row_stride = vcache.stride(-3);
    params.o_row_stride = out.stride(-3);
    params.q_nope_head_stride = q_nope.stride(-2);
    params.q_pe_head_stride = q_pe.stride(-2);
    params.k_head_stride = kcache.stride(-2);
    params.v_head_stride = vcache.stride(-2);
    params.o_head_stride = out.stride(-2);

    params.block_table = block_table.data_ptr<int>();
    params.block_table_batch_stride = block_table.stride(0);
    params.page_block_size = page_block_size;

    TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
    TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
    CHECK_DEVICE(tile_scheduler_metadata);
    CHECK_CONTIGUOUS(tile_scheduler_metadata);
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    params.num_sm_parts = tile_scheduler_metadata.size(0);

    TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
    CHECK_DEVICE(num_splits);
    CHECK_CONTIGUOUS(num_splits);
    params.num_splits_ptr = num_splits.data_ptr<int>();

    // Float32 scratch buffers with batch_size + num_sm_parts leading slots.
    at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // Only the 576-wide head instantiation is dispatched below.
    TORCH_CHECK(head_size == 576);

    if (print_param)
    {
        fprintf(stderr, "[flashmla] [mha_fwd_kvcache_mla_nope_pe] q_dtype = %s input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f\n", (q_dtype == torch::kBFloat16?"bf16":"fp16"),
            batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq,
            num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale);
    }

    if (q_dtype == torch::kBFloat16) {
        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, "auto", stream, true);
    }
#ifndef FLASH_MLA_DISABLE_FP16
    else if (q_dtype == torch::kHalf) {
        run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, "auto", stream, true);
    }
#endif
    else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }

    // Undo the head regrouping so the results use the caller's original head layout.
    out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
            .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
    softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
            .reshape({batch_size, num_heads_ori, seqlen_q_ori});

    return {out, softmax_lse};
}
/**
 * MLA attention forward over a paged, fp8-quantized KV cache where the query is
 * supplied as two tensors (q_nope and q_pe).
 *
 * Only kv_cache_dtype == "fp8_e5m2" is accepted, the query length must be 1,
 * and head_size_nope + head_size_pe must equal 576. Dispatches to
 * run_mha_fwd_splitkv_mla<..., 576>(params, kv_cache_dtype, stream, true).
 * NOTE(review): the trailing `true` presumably selects the split nope/pe query
 * path in the launcher — confirm against its declaration.
 *
 * @param q_nope                   batch_size x seqlen_q x num_heads x 512 (per the signature
 *                                 comment). NOTE: rebound to the regrouped view.
 * @param q_pe                     batch_size x seqlen_q x num_heads x 64 (per the signature
 *                                 comment). NOTE: rebound to the regrouped view.
 * @param kcache                   num_blocks x page_block_size x num_heads_k x (head_size_nope + head_size_pe).
 * @param vcache_                  Optional value cache; kcache is used when absent.
 * @param head_size_v              Value head size; must be a multiple of 32.
 * @param seqlens_k                int32, contiguous, shape (batch_size).
 * @param block_table              int32, batch_size x max_num_blocks_per_seq.
 * @param softmax_scale            Scale stored in params.scale_softmax.
 * @param is_causal                Forced to false when seqlen_q == 1 (the only accepted length).
 * @param tile_scheduler_metadata  int32, contiguous, num_sm_parts x TileSchedulerMetaDataSize.
 * @param num_splits               int32, contiguous (batch_size + 1 per the signature comment).
 * @param k_scale                  float32 CUDA tensor; its data pointer is passed as params.k_scale_ptr.
 * @param kv_cache_dtype           Must be "fp8_e5m2"; anything else raises.
 * @return {out, softmax_lse}: out is batch_size x seqlen_q x num_heads x head_size_v in q_nope's
 *         dtype, softmax_lse is batch_size x num_heads x seqlen_q in float32.
 * @throws c10::Error (via TORCH_CHECK) on any failed validation.
 */
std::vector<at::Tensor>
mha_fwd_kvcache_quantization_q_nope_pe_mla(
    at::Tensor &q_nope,                           // batch_size x seqlen_q x num_heads x 512
    at::Tensor &q_pe,                             // batch_size x seqlen_q x num_heads x 64
    const at::Tensor &kcache,                     // num_blocks x page_block_size x num_heads_k x head_size
    std::optional<const at::Tensor> &vcache_,     // num_blocks x page_block_size x num_heads_k x head_size_v
    const int head_size_v,
    const at::Tensor &seqlens_k,                  // batch_size
    const at::Tensor &block_table,                // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,    // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits,                 // batch_size + 1
    const at::Tensor &k_scale,
    const std::string &kv_cache_dtype
) {
    at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;

    auto q_dtype = q_nope.dtype();
    if (kv_cache_dtype == "fp8_e5m2")
    {
        // Quantized path: the cache holds fp8 data, so it cannot share the query dtype.
        TORCH_CHECK(kcache.dtype() != q_dtype, "kv cache is fp8-quantized, so query and key cache must not have the same dtype");
        CHECK_DEVICE(k_scale);
        TORCH_CHECK(k_scale.dtype() == torch::kFloat32, "k_scale must have dtype float32");
    }
    else
    {
        TORCH_CHECK(false, "Unsupported kv cache dtype");
    }

    CHECK_DEVICE(q_nope); CHECK_DEVICE(q_pe); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
    TORCH_CHECK(q_nope.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(q_pe.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");

    // sizes[3] below is an unchecked index, so pin the rank first.
    TORCH_CHECK(q_nope.dim() == 4, "q_nope must have shape (batch_size, seqlen_q, num_heads, head_size_nope)");
    const auto sizes = q_nope.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_ori = sizes[2];
    const int head_size_nope = sizes[3];
    const int head_size_pe = q_pe.size(3);
    const int head_size = head_size_nope + head_size_pe;
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    // Guard the modulo/division below against an empty head dimension.
    TORCH_CHECK(num_heads_k > 0, "kcache must have at least one head");
    TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (seqlen_q_ori == 1) { is_causal = false; }
    TORCH_CHECK(seqlen_q_ori == 1, "mha_fwd_kvcache_quantization_q_nope_pe_mla only support seqlen_q_ori=1");

    // Fold the query heads that share one KV head into the sequence axis, so the
    // kernel sees num_heads_k heads and seqlen_q_ori * ngroups query rows.
    const int ngroups = num_heads_ori / num_heads_k;
    const int seqlen_q = seqlen_q_ori * ngroups;
    const int num_heads = num_heads_k;
    q_nope = q_nope.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_nope}).transpose(2, 3)
            .reshape({batch_size, seqlen_q, num_heads, head_size_nope});
    q_pe = q_pe.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_pe}).transpose(2, 3)
            .reshape({batch_size, seqlen_q, num_heads, head_size_pe});

    int head_size_k = head_size;
    CHECK_SHAPE(q_nope, batch_size, seqlen_q, num_heads, head_size_nope);
    CHECK_SHAPE(q_pe, batch_size, seqlen_q, num_heads, head_size_pe);
    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);

    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    CHECK_DEVICE(seqlens_k);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_SHAPE(seqlens_k, batch_size);

    // Make q_nope's device current so the stream fetched below belongs to the
    // same GPU as the tensors (this guard was previously commented out).
    at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};

    auto opts = q_nope.options();
    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_mla_params params = {};
    // Set the sizes.
    params.b = batch_size;
    params.seqlen_q = seqlen_q;
    params.cu_seqlens_k = seqlens_k.data_ptr<int>();
    params.h = num_heads;
    params.h_h_k_ratio = num_heads / num_heads_k;  // num_heads == num_heads_k after regrouping
    params.ngroups = ngroups;
    params.is_causal = is_causal;
    params.d = head_size;
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
    // Set the pointers and strides.
    params.q_nope_ptr = q_nope.data_ptr();
    params.q_pe_ptr = q_pe.data_ptr();
    params.k_ptr = kcache.data_ptr();
    params.v_ptr = vcache.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    // All stride are in elements, not bytes.
    params.q_nope_batch_stride = q_nope.stride(0);
    params.q_pe_batch_stride = q_pe.stride(0);
    params.k_batch_stride = kcache.stride(0);
    params.v_batch_stride = vcache.stride(0);
    params.o_batch_stride = out.stride(0);
    params.q_nope_row_stride = q_nope.stride(-3);
    params.q_pe_row_stride = q_pe.stride(-3);
    params.k_row_stride = kcache.stride(-3);
    params.v_row_stride = vcache.stride(-3);
    params.o_row_stride = out.stride(-3);
    params.q_nope_head_stride = q_nope.stride(-2);
    params.q_pe_head_stride = q_pe.stride(-2);
    params.k_head_stride = kcache.stride(-2);
    params.v_head_stride = vcache.stride(-2);
    params.o_head_stride = out.stride(-2);

    params.block_table = block_table.data_ptr<int>();
    params.block_table_batch_stride = block_table.stride(0);
    params.page_block_size = page_block_size;
    params.k_scale_ptr = k_scale.data_ptr();

    TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
    TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
    CHECK_DEVICE(tile_scheduler_metadata);
    CHECK_CONTIGUOUS(tile_scheduler_metadata);
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    params.num_sm_parts = tile_scheduler_metadata.size(0);

    TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
    CHECK_DEVICE(num_splits);
    CHECK_CONTIGUOUS(num_splits);
    params.num_splits_ptr = num_splits.data_ptr<int>();

    // Float32 scratch buffers with batch_size + num_sm_parts leading slots.
    at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // Only the 576-wide head instantiation is dispatched below.
    TORCH_CHECK(head_size == 576);

    if (print_param)
    {
        fprintf(stderr, "[flashmla] [mha_fwd_kvcache_quantization_q_nope_pe_mla] q_dtype = %s input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f kv_cache_dtype = %s\n", (q_dtype == torch::kBFloat16?"bf16":"fp16"),
            batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq,
            num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale, kv_cache_dtype.c_str());
    }

    if (q_dtype == torch::kBFloat16) {
        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, kv_cache_dtype, stream, true);
    }
#ifndef FLASH_MLA_DISABLE_FP16
    else if (q_dtype == torch::kHalf) {
        run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, kv_cache_dtype, stream, true);
    }
#endif
    else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }

    // Undo the head regrouping so the results use the caller's original head layout.
    out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
            .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
    softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
            .reshape({batch_size, num_heads_ori, seqlen_q_ori});

    return {out, softmax_lse};
}
std::vector<at::Tensor>
mha_fwd_kvcache_mla_fp8(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits, // batch_size + 1
const std::optional<at::Tensor> &descale_q, // None or batch_size
const std::optional<at::Tensor> &descale_k // None or batch_size
) {
// auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90);
static std::string FLASH_MLA_ROOT_DIR = execCommand("python -c 'import site; print(site.getsitepackages()[0])'");
setenv("FLASH_MLA_ROOT_DIR", (FLASH_MLA_ROOT_DIR + "/flash_mla/asm/").c_str(), 1);
// std::cout << FLASH_MLA_ROOT_DIR << "\n";
// exit(-1);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
auto q_dtype = q.dtype();
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
if (descale_q.has_value()) CHECK_DEVICE(descale_q.value());
if (descale_k.has_value()) CHECK_DEVICE(descale_k.value());
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_ori = sizes[2];
const int head_size = sizes[3];
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
auto descale_q_ = descale_q.value();
auto descale_k_ = descale_k.value();
CHECK_DEVICE(descale_q_);
CHECK_DEVICE(descale_k_);
TORCH_CHECK(descale_q_.stride(-1) == 1);
TORCH_CHECK(descale_k_.stride(-1) == 1);
TORCH_CHECK(descale_q_.dtype() == torch::kFloat);
TORCH_CHECK(descale_k_.dtype() == torch::kFloat);
CHECK_SHAPE(descale_q_, 1);
CHECK_SHAPE(descale_k_, 1);
if (seqlen_q_ori == 1) { is_causal = false; }
const int ngroups = num_heads_ori / num_heads_k;
const int seqlen_q = seqlen_q_ori * ngroups;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
.reshape({batch_size, seqlen_q, num_heads, head_size});
int head_size_k = head_size;
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(torch::kBFloat16));//1,16,1,512
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));//1,1,16
Flash_fwd_mla_params params = {};
// Set the sizes.
params.b = batch_size;//1
params.seqlen_q = seqlen_q;//16
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
params.h = num_heads;//1
params.h_h_k_ratio = num_heads / num_heads_k;//1
params.ngroups = ngroups;//16
params.is_causal = is_causal;//false
params.d = head_size;//576
params.d_v = head_size_v;//512
params.scale_softmax = softmax_scale;//0.417
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.v_ptr = vcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = softmax_lse.data_ptr();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.v_batch_stride = vcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(-3);
params.v_row_stride = vcache.stride(-3);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(-2);
params.v_head_stride = vcache.stride(-2);
params.o_head_stride = out.stride(-2);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;//64
params.descale_q_ptr = reinterpret_cast<float *>(descale_q.value().data_ptr());
params.descale_k_ptr = reinterpret_cast<float *>(descale_k.value().data_ptr());
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_CONTIGUOUS(tile_scheduler_metadata);
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
params.num_sm_parts = tile_scheduler_metadata.size(0);
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
CHECK_DEVICE(num_splits);
CHECK_CONTIGUOUS(num_splits);
params.num_splits_ptr = num_splits.data_ptr<int>();
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size == 576);
if (print_param)
{
fprintf(stderr, "[flashmla] [mha_fwd_kvcache_mla_fp8] input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f \n",
batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq,
num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale);
}
if (q_dtype == torch::kFloat8_e4m3fn) {
run_mha_fwd_splitkv_mla_fp8<cutlass::float_e4m3_t,cutlass::bfloat16_t, 576>(params,stream,false);
}
else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
return {out, softmax_lse};
}
/**
 * MLA decode forward pass over a paged FP8 KV cache, with the query supplied as
 * two separate tensors (`q_nope` and `q_pe`) that the kernel treats as one
 * logical query of head size `q_nope.size(3) + q_pe.size(3)` (must be 576).
 *
 * Query heads that share one KV head are folded into the sequence dimension
 * before launch (seqlen_q = seqlen_q_ori * ngroups, num_heads = num_heads_k)
 * and unfolded again on the way out.
 *
 * Supported dtype combinations (anything else raises):
 *   - q_nope fp8 e4m3  + kcache fp8 e4m3
 *   - q_nope bfloat16  + kcache fp8 e4m3
 *
 * @param q_nope   batch_size x seqlen_q x num_heads x head_size_nope, last dim contiguous.
 * @param q_pe     batch_size x seqlen_q x num_heads x head_size_pe, last dim contiguous.
 * @param kcache   num_blocks x page_block_size x num_heads_k x head_size.
 * @param vcache_  optional V cache (… x head_size_v); falls back to `kcache` when empty.
 * @param head_size_v  V head size; must be a multiple of 32.
 * @param seqlens_k    int32, contiguous, shape (batch_size).
 * @param block_table  int32, batch_size x max_num_blocks_per_seq, last dim contiguous.
 * @param softmax_scale  scale applied inside the softmax.
 * @param is_causal    forced to false when seqlen_q == 1.
 * @param tile_scheduler_metadata  int32, contiguous, num_sm_parts x TileSchedulerMetaDataSize.
 * @param num_splits   int32, contiguous (documented by the original as batch_size + 1 entries).
 * @param descale_q    required; float32, validated with CHECK_SHAPE(…, 1).
 * @param descale_k    required; float32, validated with CHECK_SHAPE(…, 1).
 * @return {out, softmax_lse}: out is bfloat16 of shape
 *         (batch_size, seqlen_q, num_heads, head_size_v); softmax_lse is float32 of
 *         shape (batch_size, num_heads, seqlen_q).
 *
 * The caller's `q_nope` / `q_pe` tensors are left untouched; the regrouped
 * views/copies live in locals for the duration of the call.
 */
std::vector<at::Tensor>
mha_fwd_kvcache_mla_fp8_with_cat(
    at::Tensor &q_nope,                               // batch_size x seqlen_q x num_heads x 512
    at::Tensor &q_pe,                                 // batch_size x seqlen_q x num_heads x 64
    const at::Tensor &kcache,                         // num_blocks x page_block_size x num_heads_k x head_size
    std::optional<const at::Tensor> &vcache_,         // num_blocks x page_block_size x num_heads_k x head_size_v
    const int head_size_v,
    const at::Tensor &seqlens_k,                      // batch_size
    const at::Tensor &block_table,                    // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,        // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits,                     // batch_size + 1
    const std::optional<at::Tensor> &descale_q,       // required, float32, shape (1)
    const std::optional<at::Tensor> &descale_k        // required, float32, shape (1)
) {
    // V aliases the K cache unless a dedicated V cache was passed in.
    at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
    const auto q_dtype = q_nope.dtype();

    CHECK_DEVICE(q_nope); CHECK_DEVICE(q_pe); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

    // Rank checks come first: sizes()[i] below is an unchecked ArrayRef access.
    TORCH_CHECK(q_nope.dim() == 4, "q_nope must be a 4-D tensor");
    TORCH_CHECK(q_pe.dim() == 4, "q_pe must be a 4-D tensor");
    TORCH_CHECK(kcache.dim() == 4, "kcache must be a 4-D tensor");

    TORCH_CHECK(q_nope.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(q_pe.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dim() == 2, "block_table must be a 2-D tensor");
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");

    const auto sizes = q_nope.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_ori = sizes[2];
    const int head_size_nope = sizes[3];
    const int head_size_pe = q_pe.size(3);
    const int head_size = head_size_nope + head_size_pe;

    // q_pe is regrouped with q_nope's leading dimensions below; a mismatched
    // tensor with the same element count would otherwise slip through .view().
    TORCH_CHECK(q_pe.size(0) == batch_size && q_pe.size(1) == seqlen_q_ori && q_pe.size(2) == num_heads_ori,
                "q_pe must match q_nope in batch size, sequence length and number of heads");

    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
    // Only the 576-wide kernel is instantiated; reject early, before any allocation.
    TORCH_CHECK(head_size == 576, "combined q_nope + q_pe head size must be 576");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    // Guards the modulo / division by num_heads_k just below.
    TORCH_CHECK(num_heads_k > 0, "kcache must have at least one head");
    TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
    const at::Tensor &descale_q_ = descale_q.value();
    const at::Tensor &descale_k_ = descale_k.value();
    CHECK_DEVICE(descale_q_);
    CHECK_DEVICE(descale_k_);
    TORCH_CHECK(descale_q_.stride(-1) == 1);
    TORCH_CHECK(descale_k_.stride(-1) == 1);
    TORCH_CHECK(descale_q_.dtype() == torch::kFloat);
    TORCH_CHECK(descale_k_.dtype() == torch::kFloat);
    CHECK_SHAPE(descale_q_, 1);
    CHECK_SHAPE(descale_k_, 1);

    // A single query token has nothing to mask.
    if (seqlen_q_ori == 1) { is_causal = false; }

    // Fold the query heads that share a KV head into the sequence dimension.
    const int ngroups = num_heads_ori / num_heads_k;
    const int seqlen_q = seqlen_q_ori * ngroups;
    const int num_heads = num_heads_k;
    const at::Tensor q_nope_g =
        q_nope.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_nope}).transpose(2, 3)
              .reshape({batch_size, seqlen_q, num_heads, head_size_nope});
    const at::Tensor q_pe_g =
        q_pe.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_pe}).transpose(2, 3)
            .reshape({batch_size, seqlen_q, num_heads, head_size_pe});

    int head_size_k = head_size;
    CHECK_SHAPE(q_nope_g, batch_size, seqlen_q, num_heads, head_size_nope);
    CHECK_SHAPE(q_pe_g, batch_size, seqlen_q, num_heads, head_size_pe);
    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);

    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    CHECK_DEVICE(seqlens_k);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_SHAPE(seqlens_k, batch_size);

    // Make the inputs' device current so the stream fetched below (and the
    // launch on it) belongs to that device, matching mha_fwd_kvcache_mla_fp8.
    at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};

    auto opts = q_nope.options();
    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(torch::kBFloat16));
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_mla_params params = {};
    // Set the sizes.
    params.b = batch_size;
    params.seqlen_q = seqlen_q;
    params.cu_seqlens_k = seqlens_k.data_ptr<int>();
    params.h = num_heads;
    params.h_h_k_ratio = num_heads / num_heads_k;   // always 1 after the regrouping above
    params.ngroups = ngroups;
    params.is_causal = is_causal;
    params.d = head_size;
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);

    // Set the pointers and strides.
    params.q_nope_ptr = q_nope_g.data_ptr();
    params.q_pe_ptr = q_pe_g.data_ptr();
    params.k_ptr = kcache.data_ptr();
    params.v_ptr = vcache.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();

    // All stride are in elements, not bytes.
    params.q_nope_batch_stride = q_nope_g.stride(0);
    params.q_pe_batch_stride = q_pe_g.stride(0);
    params.k_batch_stride = kcache.stride(0);
    params.v_batch_stride = vcache.stride(0);
    params.o_batch_stride = out.stride(0);

    params.q_nope_row_stride = q_nope_g.stride(-3);
    params.q_pe_row_stride = q_pe_g.stride(-3);
    params.k_row_stride = kcache.stride(-3);
    params.v_row_stride = vcache.stride(-3);
    params.o_row_stride = out.stride(-3);

    params.q_nope_head_stride = q_nope_g.stride(-2);
    params.q_pe_head_stride = q_pe_g.stride(-2);
    params.k_head_stride = kcache.stride(-2);
    params.v_head_stride = vcache.stride(-2);
    params.o_head_stride = out.stride(-2);

    params.block_table = block_table.data_ptr<int>();
    params.block_table_batch_stride = block_table.stride(0);
    params.page_block_size = page_block_size;

    // dtype was verified to be float32 above, so the cast is exact.
    params.descale_q_ptr = static_cast<float *>(descale_q_.data_ptr());
    params.descale_k_ptr = static_cast<float *>(descale_k_.data_ptr());

    TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
    TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
    CHECK_DEVICE(tile_scheduler_metadata);
    CHECK_CONTIGUOUS(tile_scheduler_metadata);
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    params.num_sm_parts = tile_scheduler_metadata.size(0);

    TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
    CHECK_DEVICE(num_splits);
    CHECK_CONTIGUOUS(num_splits);
    params.num_splits_ptr = num_splits.data_ptr<int>();

    // Scratch buffers for the split-KV partial results.
    at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    if (print_param)
    {
        fprintf(stderr, "[flashmla] [mha_fwd_kvcache_mla_fp8_with_cat] input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f \n",
                batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq,
                num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale);
    }

    // Dispatch on the query dtype; the KV cache must be fp8 e4m3 in both cases.
    const bool kcache_is_fp8 = kcache.dtype() == torch::kFloat8_e4m3fn;
    if (kcache_is_fp8 && q_dtype == torch::kFloat8_e4m3fn) {
        run_mha_fwd_splitkv_mla_fp8<cutlass::float_e4m3_t,cutlass::bfloat16_t, 576>(params,stream,true);
    } else if (kcache_is_fp8 && q_dtype == torch::kBFloat16) {
        run_mha_fwd_splitkv_mla_fp8<cutlass::bfloat16_t,cutlass::bfloat16_t, 576>(params,stream,true);
    } else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }

    // Undo the head regrouping so the results line up with the caller's layout.
    out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
    softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
                             .reshape({batch_size, num_heads_ori, seqlen_q_ori});
    return {out, softmax_lse};
}
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// m.doc() = "FlashMLA";
// m.def("get_mla_metadata", &get_mla_metadata);
// m.def("get_mla_decoding_metadata_dense_fp8", &get_mla_decoding_metadata_dense_fp8);
// m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
// m.def("fwd_kvcache_quantization_mla", &mha_fwd_kvcache_quantization_mla);
// m.def("sparse_prefill_fwd", &sparse_prefill_fwd);
// m.def("fwd_kvcache_quantization_q_nope_pe_mla", &mha_fwd_kvcache_quantization_q_nope_pe_mla);
// m.def("fwd_kvcache_mla_nope_pe", &mha_fwd_kvcache_mla_nope_pe);
// m.def("fwd_kvcache_mla_fp8", &mha_fwd_kvcache_mla_fp8);
// m.def("fwd_kvcache_mla_fp8_with_cat", &mha_fwd_kvcache_mla_fp8_with_cat);
// }
// NOTE(review): the repeated includes suggest each include + instantiation pair
// below comes from its own small source file (one kernel instantiation per
// translation unit) and was concatenated by the diff view — confirm.

// Explicit instantiation of run_mha_fwd_splitkv_mla for cutlass::bfloat16_t, 576.
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, const std::string& kv_cache_dtype, cudaStream_t stream, bool is_q_nope_pe = false);
// Explicit instantiation of run_mha_fwd_splitkv_mla for cutlass::half_t, 576.
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, const std::string& kv_cache_dtype, cudaStream_t stream, bool is_q_nope_pe = false);
// Explicit instantiation of run_mha_fwd_splitkv_mla_fp8 for <float_e4m3_t, bfloat16_t, 576>;
// this is the variant the fp8-query host paths call.
#include "flash_fwd_mla_kernel_fp8.h"
template void run_mha_fwd_splitkv_mla_fp8<cutlass::float_e4m3_t,cutlass::bfloat16_t,576>(Flash_fwd_mla_params &params, cudaStream_t stream, bool is_with_cat);
// Explicit instantiation of run_mha_fwd_splitkv_mla_fp8 for <bfloat16_t, bfloat16_t, 576>;
// this is the variant mha_fwd_kvcache_mla_fp8_with_cat calls for a bf16 query.
#include "flash_fwd_mla_kernel_fp8.h"
template void run_mha_fwd_splitkv_mla_fp8<cutlass::bfloat16_t,cutlass::bfloat16_t,576>(Flash_fwd_mla_params &params, cudaStream_t stream, bool is_with_cat);
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
#include "flash_fwd_mla_kernel.h"
// Compile-time traits bundle for the "qkvfp8" MLA forward kernel: element
// types, tile sizes, MMA atoms, and the shared-/global-memory layouts and
// tiled copies built from them.
//
// Template parameters:
//   kHeadDim_    head dimension of Q/K; must be a multiple of 64.
//   kBlockM_     rows per Q / O tile (see SmemLayoutQ, SmemLayoutO).
//   kBlockN_     rows per K / V tile (see SmemLayoutK, SmemLayoutV).
//   kNWarps_     warp count; kNThreads is kNWarps * 64.
//   elem_type    exposed as Element.
//   elem_type_o  exposed as ElementO.
//   kHeadDimV_   head dimension of V; 0 selects kHeadDim. Must be a multiple
//                of 64 and no larger than kHeadDim.
//   is_with_cat  exposed as IS_WITH_CAT.
//   elem_type_q  exposed as ElementQ; defaults to elem_type.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t,typename elem_type_o = cutlass::bfloat16_t, int kHeadDimV_ = 0, bool is_with_cat=false, typename elem_type_q = elem_type>
struct Flash_fwd_kernel_traits_mla_qkvfp8 {
using Element = elem_type;
using ElementO = elem_type_o;
using ElementQ = elem_type_q;
using ElementAccum = float;
using index_t = int64_t;
static constexpr bool IS_WITH_CAT = is_with_cat;
static constexpr int kNWarps = kNWarps_;
// 64 threads are counted per warp here.
static constexpr int kNThreads = kNWarps * 64;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 64 == 0);
// A kHeadDimV_ of 0 means "V uses the same head dimension as Q/K".
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
static_assert(kHeadDimV % 64 == 0);
static_assert(kHeadDimV <= kHeadDim);
// static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
// Swizzle bit count shared by the K/V and O shared-memory atoms below.
static constexpr int kSwizzle = 3;
// Global-load variant of the first MMA atom (disabled in this traits struct):
//using MMA_Atom_Arch = MMA_Atom<GFX938_16x16x64_F32F8F8F32E4M3E4M3_NT_LIT>;
// Buffer-load variant (active):
using MMA_Atom_Arch = MMA_Atom<GFX938_16x16x64_F32F8F8F32E4M3E4M3_NN_LIT>;
// Second MMA atom, used for TiledMma_O. NOTE(review): the alias says 16x32 and
// the atom is named 16x32x32 — presumably MxNxK; confirm against the atom definition.
using MMA_Atom_Arch_16x32 = MMA_Atom<GFX938_16x32x32_F32F8F8F32E4M3E4M3_NT_LIT>;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
// Both tiled MMAs arrange the warps as 1 x kNWarps x 1.
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1 x kNWarps x 1 (e.g. 1x4x1 or 1x8x1) thread group
ValLayoutMNK>;//
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x32,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1 x kNWarps x 1 (e.g. 1x4x1 or 1x8x1) thread group
ValLayoutMNK>;
// using SmemLayoutRow = Layout<Shape<_16, Int<4>>, Stride<_4, _1>>;
// 128*4=512
// Flat 128-entry layout; the shared-storage structs use it for the
// row-sum / row-max reduction buffers.
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
// using SmemLayoutAtomK = Layout<Shape<Int<kBlockN>, Int<512>>, Stride<_512, _1>>;
// K in shared memory: swizzled 8x64 row-major atom tiled to kBlockN x 512.
using SmemLayoutAtomK = decltype(composition(
Swizzle<kSwizzle, 4, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<8 * 64>>{}));
// Un-swizzled kBlockN x 64 atom tiled to kBlockN x 448 (7 * 64); backs smem_temp.
using SmemLayoutAtomK_temp = Layout<Shape<Int<kBlockN>, Int<64>>, Stride<_64, _1>>;
using SmemLayoutK_temp = decltype(tile_to_shape(
SmemLayoutAtomK_temp{},
Shape<Int<kBlockN>, Int<7*64>>{}));
// V reuses the K atom, tiled to kBlockN x kHeadDimV.
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
// P buffer: flat layout of 4*16*16 = 1024 elements.
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
// View of the V buffer indexed as (kHeadDimV, kBlockN), plus its non-swizzled portion.
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
// Buffer-load variant of the Q layout (active): plain 16x64 row-major atom
// tiled to kBlockM x kHeadDim.
using SmemLayoutAtomQ =
Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
// Global-load variant of the Q layout (disabled):
// using SmemLayoutAtomQ = decltype(composition(
// Swizzle<kSwizzle, 3, 3>{},
// Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));//8*64
// using SmemLayoutQ = decltype(tile_to_shape(
// SmemLayoutAtomQ{},
// Shape<Int<kBlockM>, Int<kHeadDim>>{}));
// O in shared memory: swizzled 8x64 atom tiled to kBlockM x kHeadDimV.
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
// Register -> shared copy atoms for the final (ElementO) and partial (ElementAccum) outputs.
using SmemCopyAtomO = Copy_Atom<DefaultCopy, ElementO>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
// Global-memory tiled copies: a row-major thread arrangement combined with a
// 1 x N per-thread value layout (N = 8 for O and Q, 4 for the float accumulator).
using GmemLayoutAtomO = Layout<Shape <_16, _16>,
Stride< _16, _1>>;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementO>{},
GmemLayoutAtomO{},
Layout<Shape<_1, _8>>{}));
using GmemLayoutAtomQ = Layout<Shape <_32, _8>,
Stride< _8, _1>>;
// NOTE(review): this copy is typed on Element, not ElementQ — confirm that is
// intended when elem_type_q differs from elem_type.
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{}));
using GmemLayoutAtomOaccum = Layout<Shape <_16, _16>,
Stride< _16, _1>>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape<_1, _4>>{}));
};
// "TP1" variant of the qkvfp8 MLA kernel traits. Same overall structure as
// Flash_fwd_kernel_traits_mla_qkvfp8, but with a fixed 4 x 2 x 1 warp
// arrangement, a 256-entry row buffer, 640-wide K buffers, and P / Q atoms
// hard-coded to 64 rows.
// NOTE(review): "TP1" presumably refers to a tensor-parallel degree of 1 — confirm.
//
// Template parameters:
//   kHeadDim_    head dimension of Q/K; must be a multiple of 64.
//   kBlockM_     rows per Q / O tile.
//   kBlockN_     rows per K / V tile.
//   kNWarps_     warp count; kNThreads is kNWarps * 64.
//   elem_type    exposed as Element.
//   elem_type_o  exposed as ElementO.
//   kHeadDimV_   head dimension of V; 0 selects kHeadDim. Must be a multiple
//                of 64 and no larger than kHeadDim.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t,typename elem_type_o = cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla_qkvfp8_TP1 {
using Element = elem_type;
using ElementO = elem_type_o;
using ElementAccum = float;
using index_t = int64_t;
static constexpr int kNWarps = kNWarps_;
// 64 threads are counted per warp here.
static constexpr int kNThreads = kNWarps * 64;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 64 == 0);
// A kHeadDimV_ of 0 means "V uses the same head dimension as Q/K".
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
static_assert(kHeadDimV % 64 == 0);
static_assert(kHeadDimV <= kHeadDim);
// static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
// Swizzle bit count shared by the K/V and O shared-memory atoms below.
static constexpr int kSwizzle = 3;
// Global-load variant of the first MMA atom (active):
using MMA_Atom_Arch = MMA_Atom<GFX938_16x16x64_F32F8F8F32E4M3E4M3_NT_LIT>;
// Buffer-load variant (disabled in this traits struct):
// using MMA_Atom_Arch = MMA_Atom<GFX938_16x16x64_F32F8F8F32E4M3E4M3_NN_LIT>;
// NOTE(review): the alias is still called "_16x32" but the atom selected here
// is named 16x64x16 — the alias name looks stale.
using MMA_Atom_Arch_16x32 = MMA_Atom<GFX938_16x64x16_F32F8F8F32E4M3E4M3_NT>;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
// Both tiled MMAs use a fixed 4 x 2 x 1 warp arrangement (kNWarps is not consulted).
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_4, Int<2>, _1>>, // 4x2x1 thread group
ValLayoutMNK>;//
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x32,
Layout<Shape<_4, Int<2>, _1>>, // 4x2x1 thread group
ValLayoutMNK>;
// using SmemLayoutRow = Layout<Shape<_16, Int<4>>, Stride<_4, _1>>;
// Flat 256-entry layout; the shared-storage structs use it for the
// row-sum / row-max reduction buffers.
using SmemLayoutRow = Layout<Shape<_256>, Stride<_1>>;
// using SmemLayoutAtomK = Layout<Shape<Int<kBlockN>, Int<512>>, Stride<_512, _1>>;
// K in shared memory: swizzled 8x64 row-major atom tiled to kBlockN x 640 (10 * 64).
using SmemLayoutAtomK = decltype(composition(
Swizzle<kSwizzle, 4, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<10 * 64>>{}));
// Un-swizzled kBlockN x 64 atom tiled to kBlockN x 640; backs smem_temp.
using SmemLayoutAtomK_temp = Layout<Shape<Int<kBlockN>, Int<64>>, Stride<_64, _1>>;
using SmemLayoutK_temp = decltype(tile_to_shape(
SmemLayoutAtomK_temp{},
Shape<Int<kBlockN>, Int<10*64>>{}));
// V reuses the K atom, tiled to kBlockN x kHeadDimV.
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
// P buffer: flat layout of 64*64 = 4096 elements.
using SmemLayoutAtomP = Layout<Shape<Int<64*64>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<64*64>>{}));
// View of the V buffer indexed as (kHeadDimV, kBlockN), plus its non-swizzled portion.
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
// Buffer-load variant of the Q layout (active): plain 64x64 row-major atom
// tiled to kBlockM x kHeadDim.
using SmemLayoutAtomQ =
Layout<Shape<Int<64>, Int<64>>, Stride<Int<64>, _1>>;
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
// Global-load variant of the Q layout (disabled):
// using SmemLayoutAtomQ = decltype(composition(
// Swizzle<kSwizzle, 3, 3>{},
// Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));//8*64
// using SmemLayoutQ = decltype(tile_to_shape(
// SmemLayoutAtomQ{},
// Shape<Int<kBlockM>, Int<kHeadDim>>{}));
// O in shared memory: swizzled 8x64 atom tiled to kBlockM x kHeadDimV.
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
// Register -> shared copy atoms for the final (ElementO) and partial (ElementAccum) outputs.
using SmemCopyAtomO = Copy_Atom<DefaultCopy, ElementO>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
// Global-memory tiled copies: a row-major thread arrangement combined with a
// 1 x N per-thread value layout (N = 8 for O and Q, 4 for the float accumulator).
using GmemLayoutAtomO = Layout<Shape <_16, _16>,
Stride< _16, _1>>;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementO>{},
GmemLayoutAtomO{},
Layout<Shape<_1, _8>>{}));
using GmemLayoutAtomQ = Layout<Shape <_32, _8>,
Stride< _8, _1>>;
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{}));
using GmemLayoutAtomOaccum = Layout<Shape <_16, _16>,
Stride< _16, _1>>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape<_1, _4>>{}));
};
// "TP4" variant of the qkvfp8 MLA kernel traits. Same overall structure as
// Flash_fwd_kernel_traits_mla_qkvfp8_TP1, but with a fixed 2 x 2 x 1 warp
// arrangement, a 128-entry row buffer, and P / Q atoms sized from kBlockM
// instead of a hard-coded 64.
// NOTE(review): "TP4" presumably refers to a tensor-parallel degree of 4 — confirm.
//
// Template parameters:
//   kHeadDim_    head dimension of Q/K; must be a multiple of 64.
//   kBlockM_     rows per Q / O tile.
//   kBlockN_     rows per K / V tile.
//   kNWarps_     warp count; kNThreads is kNWarps * 64.
//   elem_type    exposed as Element.
//   elem_type_o  exposed as ElementO.
//   kHeadDimV_   head dimension of V; 0 selects kHeadDim. Must be a multiple
//                of 64 and no larger than kHeadDim.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t,typename elem_type_o = cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla_qkvfp8_TP4 {
using Element = elem_type;
using ElementO = elem_type_o;
using ElementAccum = float;
using index_t = int64_t;
static constexpr int kNWarps = kNWarps_;
// 64 threads are counted per warp here.
static constexpr int kNThreads = kNWarps * 64;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 64 == 0);
// A kHeadDimV_ of 0 means "V uses the same head dimension as Q/K".
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
static_assert(kHeadDimV % 64 == 0);
static_assert(kHeadDimV <= kHeadDim);
// static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
// Swizzle bit count shared by the K/V and O shared-memory atoms below.
static constexpr int kSwizzle = 3;
// Global-load variant of the first MMA atom (active):
using MMA_Atom_Arch = MMA_Atom<GFX938_16x16x64_F32F8F8F32E4M3E4M3_NT_LIT>;
// Buffer-load variant (disabled in this traits struct):
// using MMA_Atom_Arch = MMA_Atom<GFX938_16x16x64_F32F8F8F32E4M3E4M3_NN_LIT>;
// NOTE(review): the alias is still called "_16x32" but the atom selected here
// is named 16x64x16 — the alias name looks stale.
using MMA_Atom_Arch_16x32 = MMA_Atom<GFX938_16x64x16_F32F8F8F32E4M3E4M3_NT>;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
// Both tiled MMAs use a fixed 2 x 2 x 1 warp arrangement (kNWarps is not consulted).
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_2, Int<2>, _1>>, // 2x2x1 thread group
ValLayoutMNK>;//
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x32,
Layout<Shape<_2, Int<2>, _1>>, // 2x2x1 thread group
ValLayoutMNK>;
// using SmemLayoutRow = Layout<Shape<_16, Int<4>>, Stride<_4, _1>>;
// Flat 128-entry layout; the shared-storage structs use it for the
// row-sum / row-max reduction buffers.
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
// using SmemLayoutAtomK = Layout<Shape<Int<kBlockN>, Int<512>>, Stride<_512, _1>>;
// K in shared memory: swizzled 8x64 row-major atom tiled to kBlockN x 640 (10 * 64).
using SmemLayoutAtomK = decltype(composition(
Swizzle<kSwizzle, 4, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<10 * 64>>{}));
// Un-swizzled kBlockN x 64 atom tiled to kBlockN x 640; backs smem_temp.
using SmemLayoutAtomK_temp = Layout<Shape<Int<kBlockN>, Int<64>>, Stride<_64, _1>>;
using SmemLayoutK_temp = decltype(tile_to_shape(
SmemLayoutAtomK_temp{},
Shape<Int<kBlockN>, Int<10*64>>{}));
// V reuses the K atom, tiled to kBlockN x kHeadDimV.
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
// P buffer: flat layout of kBlockM * 64 elements.
using SmemLayoutAtomP = Layout<Shape<Int<kBlockM*64>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<kBlockM*64>>{}));
// View of the V buffer indexed as (kHeadDimV, kBlockN), plus its non-swizzled portion.
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
// Buffer-load variant of the Q layout (active): plain kBlockM x 64 row-major
// atom tiled to kBlockM x kHeadDim.
using SmemLayoutAtomQ =
Layout<Shape<Int<kBlockM>, Int<64>>, Stride<Int<64>, _1>>;
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
// Global-load variant of the Q layout (disabled):
// using SmemLayoutAtomQ = decltype(composition(
// Swizzle<kSwizzle, 3, 3>{},
// Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));//8*64
// using SmemLayoutQ = decltype(tile_to_shape(
// SmemLayoutAtomQ{},
// Shape<Int<kBlockM>, Int<kHeadDim>>{}));
// O in shared memory: swizzled 8x64 atom tiled to kBlockM x kHeadDimV.
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
// Register -> shared copy atoms for the final (ElementO) and partial (ElementAccum) outputs.
using SmemCopyAtomO = Copy_Atom<DefaultCopy, ElementO>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
// Global-memory tiled copies: a row-major thread arrangement combined with a
// 1 x N per-thread value layout (N = 8 for O and Q, 4 for the float accumulator).
using GmemLayoutAtomO = Layout<Shape <_16, _16>,
Stride< _16, _1>>;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementO>{},
GmemLayoutAtomO{},
Layout<Shape<_1, _8>>{}));
using GmemLayoutAtomQ = Layout<Shape <_32, _8>,
Stride< _8, _1>>;
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{}));
using GmemLayoutAtomOaccum = Layout<Shape <_16, _16>,
Stride< _16, _1>>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape<_1, _4>>{}));
};
namespace flash {
using namespace cute;
template<typename Kernel_traits>
struct SharedStorageMLAFloat8 {
union {
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutV>> smem_v; // Double buffer
};
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK_temp>> smem_temp; // Double buffer
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
};
struct {
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
};
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
};
};
};
template<typename Kernel_traits>
struct SharedStorageMLAFloat8_TP1 {
union {
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutV>> smem_v; // Double buffer
};
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK_temp>> smem_temp; // Double buffer
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
};
struct {
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
};
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
};
};
};
template<typename Kernel_traits>
struct SharedStorageMLAFloat8_TP4 {
union {
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutV>> smem_v; // Double buffer
};
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK_temp>> smem_temp; // Double buffer
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
};
struct {
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
};
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
};
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
__forceinline__ __device__ void store_float8(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
SharedStorage &shared_storage, AccO tOrO, Softmax softmax,float descale_k, float scale_softmax) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
using Element = typename Kernel_traits::ElementO;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
typename Kernel_traits::TiledMma_O tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
// Epilogue
const int split_offset = __ldg(params.num_splits_ptr + bidb);
Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor lse = softmax.template normalize_softmax_lse_fp8</*Is_dropout=*/false, Split>(tOrO, sRow_sum_reduce_buffer, scale_softmax, descale_k);
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
using SmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::SmemCopyAtomO,
typename Kernel_traits::SmemCopyAtomOaccum
>;
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor rO = flash::convert_type<ElementO>(tOrO);
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// __syncthreads();
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
__syncthreads();
// if (tidx >= kNThreadsS) { return; }
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
// Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
// CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
);
}
/// Computes one (batch `bidb`, head `bidh`, row-block `m_block`) tile of the split-KV MLA
/// attention forward pass with FP8 operands, then hands the result to `store_float8`.
///
/// Phases, in order:
///   1. Load Q into the MMA A-operand register fragment `tSrQ`. Depending on
///      `Kernel_traits::IS_WITH_CAT` Q comes either from `params.q_ptr`, or from the separate
///      `params.q_nope_ptr` (512 columns) and `params.q_pe_ptr` (64 columns) tensors. When the
///      separate tensors are bf16 they are staged in shared memory and converted
///      bf16 -> f32 -> fp8 on the fly.
///   2. For every KV block from `n_block_max - 1` down to `n_block_min`: look up the page in
///      `params.block_table`, stream K into shared memory / registers, accumulate S = Q*K^T,
///      mask, run the online softmax and accumulate O += P*V into eight `v4f` accumulators.
///   3. Copy the accumulators into a CuTe fragment and call `store_float8`.
///
/// The `s_waitcnt vmcnt(N)` values below are tied to the exact number and order of the
/// vector-memory loads issued before them; do not reorder or add loads without updating them.
///
/// @param params              Kernel parameters (pointers, strides, sequence lengths, block table).
/// @param bidb                Batch index.
/// @param bidh                Head index.
/// @param m_block             Index of the kBlockM-row query block.
/// @param n_split_idx         Split index, forwarded unchanged to `store_float8`.
/// @param seqlen_k            Key sequence length used for masking / bounds.
/// @param n_block_min         First KV block (inclusive) processed by this call.
/// @param n_block_max         One past the last KV block processed by this call.
/// @param NoSplit             Selects `store_float8<..., /*Split=*/false>` when true, `true` otherwise.
/// @param shared_storage      Shared-memory buffers (`smem_q`, `smem_v`, `smem_p`, `smem_row_max`, ...).
/// @param descale_k           Forwarded unchanged to `store_float8`.
/// @param scale_softmax       Forwarded unchanged to `store_float8`.
/// @param scale_softmax_log2  Scale passed to `softmax_rescale_o_fp8`.
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(const Flash_fwd_mla_params &params,
                                                                              const int bidb, const int bidh, const int m_block,
                                                                              const int n_split_idx, const int seqlen_k,
                                                                              const int n_block_min, const int n_block_max, const bool NoSplit,
                                                                              SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
    constexpr static int IS_WITH_CAT = Kernel_traits::IS_WITH_CAT;
    using Element = typename Kernel_traits::Element;
    using ElementQ = typename Kernel_traits::ElementQ;
    using index_t = typename Kernel_traits::index_t;

    extern __shared__ char shared_memory[];

    const int tidx = threadIdx.x;
    const int lane_idx = tidx % 64;
    // readfirstlane makes the wave index wave-uniform so it can be used as a scalar ("s") asm operand.
    const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64);

    // Global K tile for this head: (kBlockN x kHeadDim); original note: 64 x 576.
    // The per-page offset is added inside the main loop.
    const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));

    // Shared-memory views. sK, sV and sVt all alias shared_storage.smem_v.
    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});   // original note: 16 x 576
    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{});   // original note: 64 x 512
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{});   // original note: 64 x 512
    Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});   // original note: 16 x 64
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});                          // original note: 512 x 64
    Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{});

    typename Kernel_traits::TiledMma tiled_mma;        // S = Q * K^T; original note: 16*16*64
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    typename Kernel_traits::TiledMma_O tiled_mma_o;    // O = P * V;   original note: 16*32*32

    auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
    Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);

    // 128-bit register payload viewed as one vector, two 64-bit MMA operands, or four 32-bit words.
    union Fp8_storage
    {
        intx4_t data;
        intx2_t p[2];
        int bf16[4];
    };

    // ------------------------------------------------------------------
    // Phase 1: load Q into tSrQ.
    // ------------------------------------------------------------------
    if constexpr (!IS_WITH_CAT)
    {
        // Q is a single tensor (kBlockM x kHeadDim); original note: 16 x 576.
        const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
        Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                make_stride(params.q_row_stride, _1{}));
        // Three global->LDS copies are issued; each wait below releases the next group of slices.
        lds_direct_copy_qkvfp8<false, true, true>(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
        lds_direct_copy_qkvfp8<false, true, true>(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
        lds_direct_copy_qkvfp8<false, false, true>(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
        asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
        asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
        asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
        __syncthreads();
    }
    else
    {
        // Q is supplied as two tensors: q_nope (kBlockM x 512) and q_pe (kBlockM x 64).
        const index_t row_offset_q_nope = bidb * params.q_nope_batch_stride + m_block * kBlockM * params.q_nope_row_stride + bidh * params.q_nope_head_stride;
        Tensor gQ_nope = make_tensor(make_gmem_ptr(reinterpret_cast<ElementQ *>(params.q_nope_ptr) + row_offset_q_nope),
                                     Shape<Int<kBlockM>, Int<512>>{},
                                     make_stride(params.q_nope_row_stride, _1{}));
        const index_t row_offset_q_pe = bidb * params.q_pe_batch_stride + m_block * kBlockM * params.q_pe_row_stride + bidh * params.q_pe_head_stride;
        Tensor gQ_pe = make_tensor(make_gmem_ptr(reinterpret_cast<ElementQ *>(params.q_pe_ptr) + row_offset_q_pe),
                                   Shape<Int<kBlockM>, Int<64>>{},
                                   make_stride(params.q_pe_row_stride, _1{}));

        if constexpr (std::is_same_v<cutlass::bfloat16_t, ElementQ>) {
            // bf16 Q: stage raw bf16 at the start of dynamic shared memory, then convert to fp8 in registers.
            ElementQ* s_q = reinterpret_cast<ElementQ *>(shared_memory);

            // Issues one buffer_load_dwordx4 ... lds per lane (8 bf16 values per lane).
            //   k_idx 0/1 select a 256-column half of q_nope, k_idx 2 selects q_pe;
            //   offset_k selects which group of 8 columns inside each 16-column step is fetched.
            auto lds_direct_copy_q = [&](const int k_idx, const int offset_k) {
                struct PtrWrapper {
                    uint32_t former;
                    uint32_t latter;
                };
                PtrWrapper glob_ptr;
                if (k_idx == 2) {
                    *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gQ_pe.data().get());
                } else {
                    *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gQ_nope.data().get());
                }
                // Buffer resource descriptor: words 0-1 carry the base address; words 2-3 are
                // fixed constants (presumably num_records and format bits - confirm against the ISA).
                uint32x4_t global_addr = {0};
                global_addr[0] = (glob_ptr.former);
                global_addr[1] = (glob_ptr.latter);
                global_addr[2] = 0x80000000;
                global_addr[3] = 0x00020000;

                constexpr int bytes_per_warp = 64 * 8 * 2;           // 64 lanes * 8 elements * 2 bytes
                constexpr int bytes_per_block = bytes_per_warp * 4;  // four wave-sized chunks per staged block

                const int row_idx = lane_idx % 16;
                const int col_idx = lane_idx / 16;
                const int row_offset = row_idx;
                int col_offset;
                int offset_v;  // byte offset of this lane's 16-byte load
                if (k_idx == 2) {
                    col_offset = warp_idx * 8 + col_idx * 16;
                    offset_v = (row_offset * params.q_pe_row_stride + col_offset) * 2;
                } else {
                    col_offset = k_idx * 256 + warp_idx * 64 + col_idx * 16 + offset_k * 8;
                    offset_v = (row_offset * params.q_nope_row_stride + col_offset) * 2;
                }
                // Offset -1 marks lanes that must not fetch real data (q_pe only needs two waves;
                // rows past seqlen_q). Presumably the buffer bounds check then yields zeros - confirm.
                if (k_idx == 2 && warp_idx >= 2) {
                    offset_v = -1;
                }
                if (m_block * kBlockM + row_idx >= params.seqlen_q) {
                    offset_v = -1;
                }
                const int offset_s = 0;
                int ldsAddrPerWave = reinterpret_cast<size_t>(s_q) + warp_idx * bytes_per_warp + k_idx * bytes_per_block
                                     + offset_k * 3 * bytes_per_block;
                // NOTE(review): this asm writes LDS but declares no "memory" clobber; ordering
                // against the plain LDS reads below relies on the compiler not reordering them.
                asm volatile(
                    "s_mov_b32 m0, %1 \n\t"
                    "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
                    "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
                    :);
            };

            // Issue order defines the vmcnt values used below:
            // staged block 0, block 1, block 3, block 4, block 2 (q_pe).
            lds_direct_copy_q(0, 0);
            lds_direct_copy_q(1, 0);
            lds_direct_copy_q(0, 1);
            lds_direct_copy_q(1, 1);
            lds_direct_copy_q(2, 0);

            ElementQ* s_q_read_ptr = s_q + lane_idx * 8;
            Fp8_storage bf16_data;
            float fp32[8];
            // Four packed fp8 bytes viewed as one 32-bit word.
            union Fp8_temp{
                int32_t data;
                uint8_t p_fp8[4];
            };

            // Staged block 0 -> elements 0..7 of k-slices 0..3.
            asm volatile("s_waitcnt vmcnt(4) \n s_barrier");
            for (int k = 0; k < 4; k++) {
                bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
                for (int i = 0; i < 4; i++) {
                    // Each 32-bit word holds two bf16 values; the last argument picks which half.
                    fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false);
                    fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true);
                }
                for (int i = 0; i < 8; i+=4) {
                    Fp8_temp fp8_tmp;
                    fp8_tmp.data = 0;  // the builtin reads this as its `old` operand; keep it defined
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i], fp32[i + 1], fp8_tmp.data, false);
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2], fp32[i + 3], fp8_tmp.data, true);
                    tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0];
                    tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1];
                    tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2];
                    tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3];
                }
                s_q_read_ptr += 16 * 32;  // advance by one wave-sized chunk (bytes_per_warp)
            }

            // Staged block 1 -> elements 0..7 of k-slices 4..7.
            asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
            for (int k = 4; k < 8; k++) {
                bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
                for (int i = 0; i < 4; i++) {
                    fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false);
                    fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true);
                }
                for (int i = 0; i < 8; i+=4) {
                    Fp8_temp fp8_tmp;
                    fp8_tmp.data = 0;
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i], fp32[i + 1], fp8_tmp.data, false);
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2], fp32[i + 3], fp8_tmp.data, true);
                    tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0];
                    tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1];
                    tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2];
                    tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3];
                }
                s_q_read_ptr += 16 * 32;
            }

            // Staged block 3 (offset_k == 1, k_idx == 0) -> elements 8..15 of k-slices 0..3.
            asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
            s_q_read_ptr = s_q + lane_idx * 8 + 3 * 4 * 16 * 4 * 8;
            for (int k = 0; k < 4; k++) {
                bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
                for (int i = 0; i < 4; i++) {
                    fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false);
                    fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true);
                }
                for (int i = 8; i < 16; i+=4) {
                    Fp8_temp fp8_tmp;
                    fp8_tmp.data = 0;
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i - 8], fp32[i + 1 - 8], fp8_tmp.data, false);
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2 - 8], fp32[i + 3 - 8], fp8_tmp.data, true);
                    tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0];
                    tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1];
                    tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2];
                    tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3];
                }
                s_q_read_ptr += 16 * 32;
            }

            // Staged block 4 (offset_k == 1, k_idx == 1) -> elements 8..15 of k-slices 4..7.
            asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
            for (int k = 4; k < 8; k++) {
                bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
                for (int i = 0; i < 4; i++) {
                    fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false);
                    fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true);
                }
                for (int i = 8; i < 16; i+=4) {
                    Fp8_temp fp8_tmp;
                    fp8_tmp.data = 0;
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i - 8], fp32[i + 1 - 8], fp8_tmp.data, false);
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2 - 8], fp32[i + 3 - 8], fp8_tmp.data, true);
                    tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0];
                    tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1];
                    tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2];
                    tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3];
                }
                s_q_read_ptr += 16 * 32;
            }

            // Staged block 2 (q_pe) -> k-slice 8: first chunk fills elements 0..7, second 8..15.
            asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
            s_q_read_ptr = s_q + lane_idx * 8 + 2 * 4 * 16 * 4 * 8;
            for (int k = 8; k < 9; k++) {
                bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
                for (int i = 0; i < 4; i++) {
                    fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false);
                    fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true);
                }
                for (int i = 0; i < 8; i+=4) {
                    Fp8_temp fp8_tmp;
                    fp8_tmp.data = 0;
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i], fp32[i + 1], fp8_tmp.data, false);
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2], fp32[i + 3], fp8_tmp.data, true);
                    tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0];
                    tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1];
                    tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2];
                    tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3];
                }
                s_q_read_ptr += 16 * 32;
            }
            for (int k = 8; k < 9; k++) {
                bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
                for (int i = 0; i < 4; i++) {
                    fp32[i * 2] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, false);
                    fp32[i * 2 + 1] = __builtin_hcu_cvt_f32_bf16(bf16_data.bf16[i], false, 0, true);
                }
                for (int i = 8; i < 16; i+=4) {
                    Fp8_temp fp8_tmp;
                    fp8_tmp.data = 0;
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i - 8], fp32[i + 1 - 8], fp8_tmp.data, false);
                    fp8_tmp.data = __builtin_hcu_cvt_pk_fp8_f32(fp32[i + 2 - 8], fp32[i + 3 - 8], fp8_tmp.data, true);
                    tSrQ(i, 0, k).storage = fp8_tmp.p_fp8[0];
                    tSrQ(i + 1, 0, k).storage = fp8_tmp.p_fp8[1];
                    tSrQ(i + 2, 0, k).storage = fp8_tmp.p_fp8[2];
                    tSrQ(i + 3, 0, k).storage = fp8_tmp.p_fp8[3];
                }
                s_q_read_ptr += 16 * 32;
            }
            // The staging area is reused afterwards; make sure every wave has finished reading it.
            __syncthreads();
        } else {
            // Q tensors are not bf16: copy q_nope (slices 0,1) and q_pe (slice 2) straight into sQ.
            // NOTE(review): these calls pass *_head_stride, while gQ_nope/gQ_pe are built with
            // *_row_stride and the bf16 branch above indexes with *_row_stride - confirm which
            // stride the helper expects.
            lds_direct_copy_qkvfp8<false, true, true>(gQ_nope, sQ, 0, params.q_nope_head_stride, params.seqlen_q - m_block * kBlockM);
            lds_direct_copy_qkvfp8<false, true, true>(gQ_nope, sQ, 1, params.q_nope_head_stride, params.seqlen_q - m_block * kBlockM);
            lds_direct_copy_qkvfp8_pe<false, false, true>(gQ_pe, sQ, 2, params.q_pe_head_stride, params.seqlen_q - m_block * kBlockM);
            asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
            asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
            asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
            cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
            __syncthreads();
        }
    }

    // ------------------------------------------------------------------
    // Phase 2 setup: K/V shared-memory partitions and output accumulators.
    // ------------------------------------------------------------------
    auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);
    // Register fragment for K, partitioned per tiled_mma; it is filled from sK inside the loop.
    Tensor tSrK = thr_mma.partition_fragment_B(gK);

    auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x32_B8, Element>{}, tiled_mma_o);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
    int n_block = n_block_max - 1;

    flash::Softmax<1> softmax;

    // Output accumulators: cN_0 / cN_1 end up in acc_o(0..3, 0, N) / acc_o(4..7, 0, N).
    v4f c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1;
    c0_0.x = 0.0f; c0_0.y = 0.0f; c0_0.z = 0.0f; c0_0.w = 0.0f;
    c0_1.x = 0.0f; c0_1.y = 0.0f; c0_1.z = 0.0f; c0_1.w = 0.0f;
    c1_0.x = 0.0f; c1_0.y = 0.0f; c1_0.z = 0.0f; c1_0.w = 0.0f;
    c1_1.x = 0.0f; c1_1.y = 0.0f; c1_1.z = 0.0f; c1_1.w = 0.0f;
    c2_0.x = 0.0f; c2_0.y = 0.0f; c2_0.z = 0.0f; c2_0.w = 0.0f;
    c2_1.x = 0.0f; c2_1.y = 0.0f; c2_1.z = 0.0f; c2_1.w = 0.0f;
    c3_0.x = 0.0f; c3_0.y = 0.0f; c3_0.z = 0.0f; c3_0.w = 0.0f;
    c3_1.x = 0.0f; c3_1.y = 0.0f; c3_1.z = 0.0f; c3_1.w = 0.0f;

    // ------------------------------------------------------------------
    // Phase 2: walk the KV blocks from last to first.
    // NOTE(review): the next iteration starts writing K into smem_v without an explicit
    // s_barrier after this iteration's V reads; confirm that slower waves cannot still be
    // reading V at that point.
    // ------------------------------------------------------------------
    for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
        clear(acc_s);
        Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);

        // Scalar load of block_table[n_block], i.e. the page index of this KV block.
        int cur_block_table;
        const int *cur_block_table_ptr = block_table + n_block;
        asm volatile("s_load_dword %1, %0, 0x0\n\t"
                     "s_waitcnt lgkmcnt(0)\n\t":
                     "+s"(cur_block_table_ptr),
                     "=s"(cur_block_table));
        index_t offset_k = cur_block_table * params.k_batch_stride;
        gK.data() = gK.data() + (offset_k);  // temporarily rebase gK onto the page; undone below

        // Issue all K loads up front: slices 0..7 go global->LDS, slice 8 goes to registers.
        // That is nine outstanding vector-memory operations, matching the vmcnt countdown below.
        lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
        lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN);
        lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN);
        lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN);
        lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN);
        lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN);
        lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN);
        lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN);
        constexpr static int BUFFER_SIZE = 1;
        uint128_t buffer[BUFFER_SIZE];
        buffer_load_copy_qkvfp8<false, true, true, true>(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);

        // S += Q_k * K_k^T, consuming each K slice as soon as its load has landed.
        asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
        cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
        asm volatile("s_waitcnt vmcnt(7) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1));
        cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s);
        asm volatile("s_waitcnt vmcnt(6) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2));
        cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s);
        asm volatile("s_waitcnt vmcnt(5) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3));
        cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s);
        asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4));
        cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s);
        asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5));
        cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s);
        asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6));
        cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s);
        asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
        cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7));
        // All LDS-resident K slices have landed, so the V operand for accumulator group 3
        // can already be fetched from the same buffer while the last gemms run.
        Fp8_storage v3_0, v3_1;
        __ds_read_m32x32_row_col_rrow<3, 0, 3>(tOsVt, v3_0.data);
        __ds_read_m32x32_row_col_rrow<3, 1, 3>(tOsVt, v3_1.data);
        cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
        asm volatile("s_waitcnt vmcnt(0) \n\t");
        buffer_to_tensor(buffer[0], tSrK, 8);
        cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s);
        asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");

        gK.data() = gK.data() + (-offset_k);  // restore gK to the head base

        // Mask scores whose key column is outside the valid (or causal) range.
        Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
        Tensor tScS = thr_mma.partition_C(cS);
        for (int i = 0; i < size(acc_s); ++i) {
            if constexpr (!Is_causal) {
                if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
            } else {
                // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                int row = int(get<0>(tScS(i)));
                int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
                if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
            }
        }

        // Online softmax; the first processed block uses the Is_first=true instantiation.
        // We have key_padding_mask so Check_inf follows Is_causal.
        const bool is_first_masking_step = masking_step == 0;
        if (is_first_masking_step) {
            softmax.template softmax_rescale_o_fp8</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1);
        } else {
            softmax.template softmax_rescale_o_fp8</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1);
        }

        // Pack this thread's four probabilities to fp8 and redistribute them through sP so that
        // each lane ends up holding the 16 bytes it needs as the A operand of the P*V MMA.
        Fp8_storage data_fp8;
        {
            const int tid = threadIdx.x % 64;
            const int warp_id = threadIdx.x / 64;
            int32_t result = 0;  // the builtin reads this as its `old` operand; keep it defined
            result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0), acc_s(1), result, false);
            result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2), acc_s(3), result, true);
            int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 0]));
            *lds_ptr = result;
            __syncthreads();
            data_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16]));
        }

        // O += P * V. The sched_barriers pin the hand-written interleaving of LDS reads and MMAs.
        {
            Fp8_storage v0_0, v0_1;
            Fp8_storage v1_0, v1_1;
            Fp8_storage v2_0, v2_1;
            __builtin_amdgcn_sched_barrier(0);
            __ds_read_m32x32_row_col_rrow<0, 0, 0>(tOsVt, v0_0.data);
            __ds_read_m32x32_row_col_rrow<1, 0, 1>(tOsVt, v1_0.data);
            __ds_read_m32x32_row_col_rrow<2, 0, 2>(tOsVt, v2_0.data);
            // Group 3 operands were prefetched above, so its MMAs run while groups 0..2 load.
            c3_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[0], c3_0, true, false);
            c3_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[1], c3_1, true, false);
            c3_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v3_1.p[0], c3_0, true, false);
            c3_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v3_1.p[1], c3_1, true, false);
            c0_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v0_0.p[0], c0_0, true, false);
            c0_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v0_0.p[1], c0_1, true, false);
            c1_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v1_0.p[0], c1_0, true, false);
            c1_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v1_0.p[1], c1_1, true, false);
            c2_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v2_0.p[0], c2_0, true, false);
            c2_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v2_0.p[1], c2_1, true, false);
            __ds_read_m32x32_row_col_rrow<0, 1, 0>(tOsVt, v0_1.data);
            __ds_read_m32x32_row_col_rrow<1, 1, 1>(tOsVt, v1_1.data);
            __ds_read_m32x32_row_col_rrow<2, 1, 2>(tOsVt, v2_1.data);
            c0_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v0_1.p[0], c0_0, true, false);
            c0_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v0_1.p[1], c0_1, true, false);
            c1_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v1_1.p[0], c1_0, true, false);
            c1_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v1_1.p[1], c1_1, true, false);
            c2_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v2_1.p[0], c2_0, true, false);
            c2_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v2_1.p[1], c2_1, true, false);
            __builtin_amdgcn_sched_barrier(0);
        }
    }

    // ------------------------------------------------------------------
    // Phase 3: move the raw accumulators into a CuTe fragment and store.
    // ------------------------------------------------------------------
    Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
    acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_0.y; acc_o(2, 0, 0) = c0_0.z; acc_o(3, 0, 0) = c0_0.w;
    acc_o(4, 0, 0) = c0_1.x; acc_o(5, 0, 0) = c0_1.y; acc_o(6, 0, 0) = c0_1.z; acc_o(7, 0, 0) = c0_1.w;
    acc_o(0, 0, 1) = c1_0.x; acc_o(1, 0, 1) = c1_0.y; acc_o(2, 0, 1) = c1_0.z; acc_o(3, 0, 1) = c1_0.w;
    acc_o(4, 0, 1) = c1_1.x; acc_o(5, 0, 1) = c1_1.y; acc_o(6, 0, 1) = c1_1.z; acc_o(7, 0, 1) = c1_1.w;
    acc_o(0, 0, 2) = c2_0.x; acc_o(1, 0, 2) = c2_0.y; acc_o(2, 0, 2) = c2_0.z; acc_o(3, 0, 2) = c2_0.w;
    acc_o(4, 0, 2) = c2_1.x; acc_o(5, 0, 2) = c2_1.y; acc_o(6, 0, 2) = c2_1.z; acc_o(7, 0, 2) = c2_1.w;
    acc_o(0, 0, 3) = c3_0.x; acc_o(1, 0, 3) = c3_0.y; acc_o(2, 0, 3) = c3_0.z; acc_o(3, 0, 3) = c3_0.w;
    acc_o(4, 0, 3) = c3_1.x; acc_o(5, 0, 3) = c3_1.y; acc_o(6, 0, 3) = c3_1.z; acc_o(7, 0, 3) = c3_1.w;

    // NoSplit is a runtime flag, so both store_float8 variants are instantiated.
    if (NoSplit)
        store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
    else
        store_float8<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP1(const Flash_fwd_mla_params &params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit,
SharedStorage &shared_storage,const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
using Element = typename Kernel_traits::Element;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
const int warp_id = tidx / 64;
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));//16*576
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));//64*576
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.k_row_stride, _1{}));//64*512
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); //16,576
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{});//64,512
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); //64,512
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); //16*64
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});//512,64
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); //64
Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{});
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);//16*16*64
typename Kernel_traits::TiledMma_O tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);//16*32*32
union Fp8_storage
{
intx4_t data;
intx2_t p[2];
int32_t fp8_array[4];
};
#if 0
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
__syncthreads();
#else
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
lds_direct_copy_qkvfp8_q_tp1<false, true>(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp1<false, true>(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp1<false, true>(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp1<false, true>(gQ, sQ, 3, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp1<false, false>(gQ, sQ, 4, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
uint8_t* q_lds_read_ptr = reinterpret_cast<uint8_t*>(sQ.data().get()) + (tidx % 64) * 16 + (warp_id % 4) * (16 * 64);
{
int k = 0;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 64*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// int k = 0;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 64*64;
int k = 2;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 64*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 2;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 64*64;
int k = 4;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 64*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 4;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 64*64;
int k = 6;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 64*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 6;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 64*64;
int k = 8;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 8;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
__syncthreads();
#endif
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
Tensor tSrK = thr_mma.partition_fragment_B(gK); // per-thread K fragment (operand B of tiled_mma); filled from sK by the cute::copy calls in the loop below
// Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); //128,64 gk->rk
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M64x16_B8, Element>{}, tiled_mma_o);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle);
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int n_block = n_block_max - 1;
union Val
{
intx2_t val_to_mmac;
int32_t data[2];
};
// Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
// clear(acc_o);
flash::Softmax<1> softmax;
v4f acco_f32[16];
for (int i = 0; i < 16; i++)
{
acco_f32[i].x = 0.0f;
acco_f32[i].y = 0.0f;
acco_f32[i].z = 0.0f;
acco_f32[i].w = 0.0f;
}
constexpr static int STAGE = 8;
for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
clear(acc_s);
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// asm volatile("s_barrier \n\t");
int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k);
lds_direct_copy_qkvfp8_tp1<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8_tp1<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8_tp1<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8_tp1<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8_tp1<false, false>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2));
cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3));
cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4));
cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5));
cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6));
cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7));
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 8), tSrK_copy_view(_, _, 8));
cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s);
gK.data() = gK.data() + (-offset_k);
// asm volatile("s_barrier \n\t");
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
// asm volatile("s_barrier \n\t");
{
const bool is_first_masking_step = masking_step == 0;
// is_first_masking_step
// ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2)
// : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2);
is_first_masking_step
? softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/true, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32)
: softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/false, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32);
}
// asm volatile("s_barrier \n\t");
// if (block0() && tidx < 64)
// {
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// // acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// // acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// // acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// // );
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// acc_s(4), acc_s(5), acc_s(6), acc_s(7)
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// );
// }
#if 1
Fp8_storage p_fp8;
{
__builtin_amdgcn_sched_barrier(0);
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
int32_t result;
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 0), acc_s(1, 0, 0), result, false);
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 0), acc_s(3, 0, 0), result, true);
// int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64]));
// *lds_ptr = result;
int32_t result1;
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 1), acc_s(1, 0, 1), result1, false);
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 1), acc_s(3, 0, 1), result1, true);
// lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64 + 8]));
// *lds_ptr = result1;
int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16) * 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64]));
*lds_ptr = result;
int32_t* lds_ptr1 = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16 )* 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64 + 8]));
*lds_ptr1 = result1;
__syncthreads();
p_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16 + (warp_id % 4) * 16 * 64]));
__builtin_amdgcn_sched_barrier(0);
}
{
__builtin_amdgcn_sched_barrier(0);
for (int i = 0; i < 4; i++)
{
{
int k = 0;
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k))));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k+1))));
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// }
for (int j = 0; j < 4; j++)
{
Val tmp;
tmp.data[0] = v0_0.fp8_array[j];
tmp.data[1] = v0_1.fp8_array[j];
acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
}
}
{
int k = 2;
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k))));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k+1))));
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// }
for (int j = 0; j < 4; j++)
{
Val tmp;
tmp.data[0] = v0_0.fp8_array[j];
tmp.data[1] = v0_1.fp8_array[j];
acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
}
}
}
__builtin_amdgcn_sched_barrier(0);
}
// if (block0())
// {
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// );
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// // acc_s(4), acc_s(5), acc_s(6), acc_s(7),
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// // );
// }
asm volatile("s_barrier \n\t");
#endif
}
using ElementO = typename Kernel_traits::ElementO;
using ElementAccum = typename Kernel_traits::ElementAccum;
const int split_offset = __ldg(params.num_splits_ptr + bidb);
// Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{});
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
if (NoSplit) {
constexpr bool Split = false;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse_fp8_tp1</*Is_dropout=*/false, Split, true>(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// if (tidx == 1)
// {
// printf(" %.4f %.4f %.4f %.4f \n ", acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w);
// }
{
using result_type = cutlass::Array<bfloat16_t, 2>;
int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16 + (warpid % 4) * 16;
if (row < params.seqlen_q - m_block * kBlockM) {
for (int n = 0; n < 4; n++)
{
col = (tidx % 64 / 16) * 16 + n * 128 + (warpid / 4) * 64;
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].x, 0, acco_f32[n * 4 + 1].x, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].x, 0, acco_f32[n * 4 + 3].x, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gOaccum(row, col) = res0[0];
gOaccum(row, col + 1) = res0[1];
gOaccum(row, col + 2) = res1[0];
gOaccum(row, col + 3) = res1[1];
col += 4;
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].y, 0, acco_f32[n * 4 + 1].y, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].y, 0, acco_f32[n * 4 + 3].y, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gOaccum(row, col) = res0[0];
gOaccum(row, col + 1) = res0[1];
gOaccum(row, col + 2) = res1[0];
gOaccum(row, col + 3) = res1[1];
col += 4;
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].z, 0, acco_f32[n * 4 + 1].z, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].z, 0, acco_f32[n * 4 + 3].z, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gOaccum(row, col) = res0[0];
gOaccum(row, col + 1) = res0[1];
gOaccum(row, col + 2) = res1[0];
gOaccum(row, col + 3) = res1[1];
col += 4;
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].w, 0, acco_f32[n * 4 + 1].w, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].w, 0, acco_f32[n * 4 + 3].w, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gOaccum(row, col) = res0[0];
gOaccum(row, col + 1) = res0[1];
gOaccum(row, col + 2) = res1[0];
gOaccum(row, col + 3) = res1[1];
// col += 16;
}
// for (int j = 0; j < 4; j++)
// {
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].x, 0, acco_f32[n * 4 + j].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].z, 0, acco_f32[n * 4 + j].w, 0);
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// col += 16;
// }
}
// for (int n = 0; n < 8; n++)
// {
// using result_type = cutlass::Array<bfloat16_t, 2>;
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].x, 0, acco_f32[n].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].z, 0, acco_f32[n].w, 0);
// col = (tidx % 64 / 16) * 4 + n * 64;
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// }
}
}
}
} else {
constexpr bool Split = true;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse_fp8_tp1</*Is_dropout=*/false, Split, true>(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16 + (warpid % 4) * 16;
if (row < params.seqlen_q - m_block * kBlockM) {
// for (int n = 0; n < 32; n++)
// {
// col = (tidx % 64 / 16) * 4 + n * 16;
// gOaccum(row, col) = acco_f32[n].x;
// gOaccum(row, col + 1) = acco_f32[n].y;
// gOaccum(row, col + 2) = acco_f32[n].z;
// gOaccum(row, col + 3) = acco_f32[n].w;
// }
for (int n = 0; n < 4; n++)
{
col = (tidx % 64 / 16) * 16 + n * 128 + (warp_id / 4) * 64;
{
gOaccum(row, col) = acco_f32[n * 4 + 0].x;
gOaccum(row, col + 1) = acco_f32[n * 4 + 1].x;
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].x;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].x;
col += 4;
}
{
gOaccum(row, col) = acco_f32[n * 4 + 0].y;
gOaccum(row, col + 1) = acco_f32[n * 4 + 1].y;
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].y;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].y;
col += 4;
}
{
gOaccum(row, col) = acco_f32[n * 4 + 0].z;
gOaccum(row, col + 1) = acco_f32[n * 4 + 1].z;
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].z;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].z;
col += 4;
}
{
gOaccum(row, col) = acco_f32[n * 4 + 0].w;
gOaccum(row, col + 1) = acco_f32[n * 4 + 1].w;
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].w;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].w;
}
// for (int j = 0; j < 4; j++) {
// gOaccum(row, col) = acco_f32[n * 4 + j].x;
// gOaccum(row, col + 1) = acco_f32[n * 4 + j].y;
// gOaccum(row, col + 2) = acco_f32[n * 4 + j].z;
// gOaccum(row, col + 3) = acco_f32[n * 4 + j].w;
// col += 16;
// }
}
}
}
}
// Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
// for (int n = 0; n < 8; n++)
// {
// acc_o(0, 0, n) = acco_f32[n * 2].x;
// acc_o(1, 0, n) = acco_f32[n * 2].y;
// acc_o(2, 0, n) = acco_f32[n * 2].z;
// acc_o(3, 0, n) = acco_f32[n * 2].w;
// acc_o(4, 0, n) = acco_f32[n * 2 + 1].x;
// acc_o(5, 0, n) = acco_f32[n * 2 + 1].y;
// acc_o(6, 0, n) = acco_f32[n * 2 + 1].z;
// acc_o(7, 0, n) = acco_f32[n * 2 + 1].w;
// }
// if (NoSplit)
// store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
// else
// store_float8<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4(const Flash_fwd_mla_params &params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit,
SharedStorage &shared_storage,const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
using Element = typename Kernel_traits::Element;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
const int warp_id = tidx / 64;
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));//16*576
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));//64*576
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.k_row_stride, _1{}));//64*512
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); //16,576
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutV{});//64,512
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Kernel_traits::SmemLayoutK{}); //64,512
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); //16*64
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});//512,64
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_max.data()), typename Kernel_traits::SmemLayoutRow{}); //64
Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{});
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);//16*16*64
typename Kernel_traits::TiledMma_O tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);//16*32*32
union Fp8_storage
{
intx4_t data;
intx2_t p[2];
int32_t fp8_array[4];
};
#if 0
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
__syncthreads();
#else
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
lds_direct_copy_qkvfp8_q_tp4<false, true>(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp4<false, true>(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp4<false, true>(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp4<false, true>(gQ, sQ, 3, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp4<false, false>(gQ, sQ, 4, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
uint8_t* q_lds_read_ptr = reinterpret_cast<uint8_t*>(sQ.data().get()) + (tidx % 64) * 16 + (warp_id % 2) * (16 * 64);
{
int k = 0;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 32*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// int k = 0;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 32*64;
int k = 2;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 32*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 2;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 32*64;
int k = 4;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 32*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 4;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 32*64;
int k = 6;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 32*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 6;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 32*64;
int k = 8;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 8;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
__syncthreads();
#endif
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
Tensor tSrK = thr_mma.partition_fragment_B(gK); // per-thread K fragment (operand B of tiled_mma); filled from sK by the cute::copy calls in the loop below
// Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); //128,64 gk->rk
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M64x16_B8, Element>{}, tiled_mma_o);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle);
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int n_block = n_block_max - 1;
union Val
{
intx2_t val_to_mmac;
int32_t data[2];
};
// Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
// clear(acc_o);
flash::Softmax<1> softmax;
v4f acco_f32[16];
for (int i = 0; i < 16; i++)
{
acco_f32[i].x = 0.0f;
acco_f32[i].y = 0.0f;
acco_f32[i].z = 0.0f;
acco_f32[i].w = 0.0f;
}
constexpr static int STAGE = 8;
for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
clear(acc_s);
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// asm volatile("s_barrier \n\t");
int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 8, params.k_row_stride, seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
asm volatile("s_waitcnt vmcnt(7) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s);
asm volatile("s_waitcnt vmcnt(6) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2));
cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s);
asm volatile("s_waitcnt vmcnt(5) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3));
cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4));
cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5));
cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6));
cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7));
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 8), tSrK_copy_view(_, _, 8));
cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s);
gK.data() = gK.data() + (-offset_k);
// asm volatile("s_barrier \n\t");
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
// asm volatile("s_barrier \n\t");
{
const bool is_first_masking_step = masking_step == 0;
// is_first_masking_step
// ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2)
// : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2);
is_first_masking_step
? softmax.template softmax_rescale_o_fp8_tp4</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32)
: softmax.template softmax_rescale_o_fp8_tp4</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32);
}
// asm volatile("s_barrier \n\t");
// if (block0() && tidx < 64)
// {
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// // acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// // acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// // acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// // );
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// acc_s(4), acc_s(5), acc_s(6), acc_s(7)
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// );
// }
#if 1
Fp8_storage p_fp8;
{
__builtin_amdgcn_sched_barrier(0);
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
int32_t result;
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 0), acc_s(1, 0, 0), result, false);
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 0), acc_s(3, 0, 0), result, true);
// int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64]));
// *lds_ptr = result;
int32_t result1;
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 1), acc_s(1, 0, 1), result1, false);
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 1), acc_s(3, 0, 1), result1, true);
// lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64 + 8]));
// *lds_ptr = result1;
int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16) * 16 * 16 + (warp_id / 2) * 4 + (warp_id % 2) * 16 * 64]));
*lds_ptr = result;
int32_t* lds_ptr1 = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16 )* 16 * 16 + (warp_id / 2) * 4 + (warp_id % 2) * 16 * 64 + 8]));
*lds_ptr1 = result1;
__syncthreads();
p_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16 + (warp_id % 2) * 16 * 64]));
__builtin_amdgcn_sched_barrier(0);
}
{
__builtin_amdgcn_sched_barrier(0);
for (int i = 0; i < 4; i++)
{
{
int k = 0;
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k))));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k+1))));
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// }
for (int j = 0; j < 4; j++)
{
Val tmp;
tmp.data[0] = v0_0.fp8_array[j];
tmp.data[1] = v0_1.fp8_array[j];
acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
}
}
{
int k = 2;
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k))));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k+1))));
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// }
for (int j = 0; j < 4; j++)
{
Val tmp;
tmp.data[0] = v0_0.fp8_array[j];
tmp.data[1] = v0_1.fp8_array[j];
acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
}
}
}
__builtin_amdgcn_sched_barrier(0);
}
// if (block0())
// {
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// );
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// // acc_s(4), acc_s(5), acc_s(6), acc_s(7),
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// // );
// }
// asm volatile("s_barrier \n\t");
#endif
}
using ElementO = typename Kernel_traits::ElementO;
using ElementAccum = typename Kernel_traits::ElementAccum;
const int split_offset = __ldg(params.num_splits_ptr + bidb);
// Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{});
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
if (NoSplit) {
constexpr bool Split = false;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse_fp8_tp4</*Is_dropout=*/false, Split, true>(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// if (tidx == 1)
// {
// printf(" %.4f %.4f %.4f %.4f \n ", acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w);
// }
{
using result_type = cutlass::Array<bfloat16_t, 2>;
int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16 + (warpid % 2) * 16;
if (row < params.seqlen_q - m_block * kBlockM) {
for (int n = 0; n < 4; n++)
{
col = (tidx % 64 / 16) * 16 + n * 128 + (warpid / 2) * 64;
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].x, 0, acco_f32[n * 4 + 1].x, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].x, 0, acco_f32[n * 4 + 3].x, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gOaccum(row, col) = res0[0];
gOaccum(row, col + 1) = res0[1];
gOaccum(row, col + 2) = res1[0];
gOaccum(row, col + 3) = res1[1];
col += 4;
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].y, 0, acco_f32[n * 4 + 1].y, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].y, 0, acco_f32[n * 4 + 3].y, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gOaccum(row, col) = res0[0];
gOaccum(row, col + 1) = res0[1];
gOaccum(row, col + 2) = res1[0];
gOaccum(row, col + 3) = res1[1];
col += 4;
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].z, 0, acco_f32[n * 4 + 1].z, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].z, 0, acco_f32[n * 4 + 3].z, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gOaccum(row, col) = res0[0];
gOaccum(row, col + 1) = res0[1];
gOaccum(row, col + 2) = res1[0];
gOaccum(row, col + 3) = res1[1];
col += 4;
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 0].w, 0, acco_f32[n * 4 + 1].w, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + 2].w, 0, acco_f32[n * 4 + 3].w, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gOaccum(row, col) = res0[0];
gOaccum(row, col + 1) = res0[1];
gOaccum(row, col + 2) = res1[0];
gOaccum(row, col + 3) = res1[1];
// col += 16;
}
// for (int j = 0; j < 4; j++)
// {
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].x, 0, acco_f32[n * 4 + j].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].z, 0, acco_f32[n * 4 + j].w, 0);
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// col += 16;
// }
}
// for (int n = 0; n < 8; n++)
// {
// using result_type = cutlass::Array<bfloat16_t, 2>;
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].x, 0, acco_f32[n].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].z, 0, acco_f32[n].w, 0);
// col = (tidx % 64 / 16) * 4 + n * 64;
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// }
}
}
}
} else {
constexpr bool Split = true;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse_fp8_tp4</*Is_dropout=*/false, Split, true>(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
int tidx = threadIdx.x;
int col = 0;
int warpid = tidx / 64;
for (int m = 0; m < 1; m++) {
const int row = tidx % 16 + (warpid % 2) * 16;
if (row < params.seqlen_q - m_block * kBlockM) {
// for (int n = 0; n < 32; n++)
// {
// col = (tidx % 64 / 16) * 4 + n * 16;
// gOaccum(row, col) = acco_f32[n].x;
// gOaccum(row, col + 1) = acco_f32[n].y;
// gOaccum(row, col + 2) = acco_f32[n].z;
// gOaccum(row, col + 3) = acco_f32[n].w;
// }
for (int n = 0; n < 4; n++)
{
col = (tidx % 64 / 16) * 16 + n * 128 + (warp_id / 2) * 64;
{
gOaccum(row, col) = acco_f32[n * 4 + 0].x;
gOaccum(row, col + 1) = acco_f32[n * 4 + 1].x;
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].x;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].x;
col += 4;
}
{
gOaccum(row, col) = acco_f32[n * 4 + 0].y;
gOaccum(row, col + 1) = acco_f32[n * 4 + 1].y;
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].y;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].y;
col += 4;
}
{
gOaccum(row, col) = acco_f32[n * 4 + 0].z;
gOaccum(row, col + 1) = acco_f32[n * 4 + 1].z;
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].z;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].z;
col += 4;
}
{
gOaccum(row, col) = acco_f32[n * 4 + 0].w;
gOaccum(row, col + 1) = acco_f32[n * 4 + 1].w;
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].w;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].w;
}
// for (int j = 0; j < 4; j++) {
// gOaccum(row, col) = acco_f32[n * 4 + j].x;
// gOaccum(row, col + 1) = acco_f32[n * 4 + j].y;
// gOaccum(row, col + 2) = acco_f32[n * 4 + j].z;
// gOaccum(row, col + 3) = acco_f32[n * 4 + j].w;
// col += 16;
// }
}
}
}
}
// Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
// for (int n = 0; n < 8; n++)
// {
// acc_o(0, 0, n) = acco_f32[n * 2].x;
// acc_o(1, 0, n) = acco_f32[n * 2].y;
// acc_o(2, 0, n) = acco_f32[n * 2].z;
// acc_o(3, 0, n) = acco_f32[n * 2].w;
// acc_o(4, 0, n) = acco_f32[n * 2 + 1].x;
// acc_o(5, 0, n) = acco_f32[n * 2 + 1].y;
// acc_o(6, 0, n) = acco_f32[n * 2 + 1].z;
// acc_o(7, 0, n) = acco_f32[n * 2 + 1].w;
// }
// if (NoSplit)
// store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
// else
// store_float8<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
}
// Split-KV MLA forward kernel for FP8 inputs (base variant).
//
// Each thread block picks the tile-scheduler entry selected by blockIdx.z. The
// entry's first four ints are (begin_idx, begin_seqlen, end_idx, end_seqlen) and
// the fifth is the split index to use for the first batch entry. The block then
// visits every batch index in [begin_idx, end_idx] and runs the gfx938 row-block
// routine for it. When __gfx938__ is not defined the per-batch call is compiled out.
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1)
flash_fwd_splitkv_mla_kernel_fp8(const Flash_fwd_mla_params params) {
    constexpr int kBlockN = Kernel_traits::kBlockN;

    // Grid mapping: x -> query row block, y -> head, z -> scheduler partition.
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);

    int *sched_entry = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    const int4 sched_range = *reinterpret_cast<int4 *>(sched_entry);
    const int begin_idx = sched_range.x;
    const int begin_seqlen = sched_range.y;
    const int end_idx = sched_range.z;
    const int end_seqlen = sched_range.w;

    // Partitions whose first batch index is past the batch have nothing to do.
    if (begin_idx >= params.b) {
        return;
    }
    const int begin_n_split_idx = sched_entry[4];

    // Fold the Q and K dequantization scales into both softmax scales.
    // The left-to-right multiplication order is kept on purpose.
    const float descale_q = __ldg(params.descale_q_ptr);
    float descale_k = __ldg(params.descale_k_ptr);
    float scale_softmax = params.scale_softmax * descale_q * descale_k;
    float scale_softmax_log2 = params.scale_softmax_log2 * descale_q * descale_k;

    #pragma unroll 1
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        const bool is_first_batch = (batch_id == begin_idx);
        const bool is_last_batch = (batch_id == end_idx);

        const int seqlen_k = params.cu_seqlens_k[batch_id];
        const int total_n_blocks = cute::ceil_div(seqlen_k, kBlockN);

        const int n_split_idx = is_first_batch ? begin_n_split_idx : 0;
        const int n_block_min = is_first_batch ? begin_seqlen / kBlockN : 0;
        const int n_block_max = is_last_batch ? cute::ceil_div(end_seqlen, kBlockN) : total_n_blocks;
        // True when this block covers every K block of the batch entry by itself.
        const bool NoSplit = (n_block_min == 0) && (n_block_max == total_n_blocks);

        if (!is_first_batch) {
            __syncthreads();  // Barrier between two tiles.
        }
#if defined(__gfx938__)
        flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx938<Kernel_traits, Is_causal>(
            params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max,
            NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2);
#endif
    }
}
// Split-KV MLA forward kernel for FP8 inputs (TP1 variant).
//
// Each thread block picks the tile-scheduler entry selected by blockIdx.z. The
// entry's first four ints are (begin_idx, begin_seqlen, end_idx, end_seqlen) and
// the fifth is the split index to use for the first batch entry. The block then
// visits every batch index in [begin_idx, end_idx] and runs the gfx938 TP1
// row-block routine for it. When __gfx938__ is not defined the per-batch call is
// compiled out.
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1)
flash_fwd_splitkv_mla_kernel_fp8_tp1(const Flash_fwd_mla_params params) {
    constexpr int kBlockN = Kernel_traits::kBlockN;

    // Grid mapping: x -> query row block, y -> head, z -> scheduler partition.
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);

    int *sched_entry = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    const int4 sched_range = *reinterpret_cast<int4 *>(sched_entry);
    const int begin_idx = sched_range.x;
    const int begin_seqlen = sched_range.y;
    const int end_idx = sched_range.z;
    const int end_seqlen = sched_range.w;

    // Partitions whose first batch index is past the batch have nothing to do.
    if (begin_idx >= params.b) {
        return;
    }
    const int begin_n_split_idx = sched_entry[4];

    // Fold the Q and K dequantization scales into both softmax scales.
    // The left-to-right multiplication order is kept on purpose.
    const float descale_q = __ldg(params.descale_q_ptr);
    float descale_k = __ldg(params.descale_k_ptr);
    float scale_softmax = params.scale_softmax * descale_q * descale_k;
    float scale_softmax_log2 = params.scale_softmax_log2 * descale_q * descale_k;

    #pragma unroll 1
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        const bool is_first_batch = (batch_id == begin_idx);
        const bool is_last_batch = (batch_id == end_idx);

        const int seqlen_k = params.cu_seqlens_k[batch_id];
        const int total_n_blocks = cute::ceil_div(seqlen_k, kBlockN);

        const int n_split_idx = is_first_batch ? begin_n_split_idx : 0;
        const int n_block_min = is_first_batch ? begin_seqlen / kBlockN : 0;
        const int n_block_max = is_last_batch ? cute::ceil_div(end_seqlen, kBlockN) : total_n_blocks;
        // True when this block covers every K block of the batch entry by itself.
        const bool NoSplit = (n_block_min == 0) && (n_block_max == total_n_blocks);

        if (!is_first_batch) {
            __syncthreads();  // Barrier between two tiles.
        }
#if defined(__gfx938__)
        flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP1<Kernel_traits, Is_causal>(
            params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max,
            NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2);
#endif
    }
}
// Split-KV MLA forward kernel for FP8 inputs (TP4 variant).
//
// Each thread block picks the tile-scheduler entry selected by blockIdx.z. The
// entry's first four ints are (begin_idx, begin_seqlen, end_idx, end_seqlen) and
// the fifth is the split index to use for the first batch entry. The block then
// visits every batch index in [begin_idx, end_idx] and runs the gfx938 TP4
// row-block routine for it. When __gfx938__ is not defined the per-batch call is
// compiled out.
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1)
flash_fwd_splitkv_mla_kernel_fp8_tp4(const Flash_fwd_mla_params params) {
    constexpr int kBlockN = Kernel_traits::kBlockN;

    // Grid mapping: x -> query row block, y -> head, z -> scheduler partition.
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);

    int *sched_entry = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    const int4 sched_range = *reinterpret_cast<int4 *>(sched_entry);
    const int begin_idx = sched_range.x;
    const int begin_seqlen = sched_range.y;
    const int end_idx = sched_range.z;
    const int end_seqlen = sched_range.w;

    // Partitions whose first batch index is past the batch have nothing to do.
    if (begin_idx >= params.b) {
        return;
    }
    const int begin_n_split_idx = sched_entry[4];

    // Fold the Q and K dequantization scales into both softmax scales.
    // The left-to-right multiplication order is kept on purpose.
    const float descale_q = __ldg(params.descale_q_ptr);
    float descale_k = __ldg(params.descale_k_ptr);
    float scale_softmax = params.scale_softmax * descale_q * descale_k;
    float scale_softmax_log2 = params.scale_softmax_log2 * descale_q * descale_k;

    #pragma unroll 1
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        const bool is_first_batch = (batch_id == begin_idx);
        const bool is_last_batch = (batch_id == end_idx);

        const int seqlen_k = params.cu_seqlens_k[batch_id];
        const int total_n_blocks = cute::ceil_div(seqlen_k, kBlockN);

        const int n_split_idx = is_first_batch ? begin_n_split_idx : 0;
        const int n_block_min = is_first_batch ? begin_seqlen / kBlockN : 0;
        const int n_block_max = is_last_batch ? cute::ceil_div(end_seqlen, kBlockN) : total_n_blocks;
        // True when this block covers every K block of the batch entry by itself.
        const bool NoSplit = (n_block_min == 0) && (n_block_max == total_n_blocks);

        if (!is_first_batch) {
            __syncthreads();  // Barrier between two tiles.
        }
#if defined(__gfx938__)
        flash::compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4<Kernel_traits, Is_causal>(
            params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max,
            NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2);
#endif
    }
}
} // namespace flash
////////////////////////////////////////////////////////////////////////////////////////////////////
// Launches the TP1 FP8 split-KV MLA forward pass and then the split-combine kernel.
//
// Two launch paths exist for the main kernel:
//   * get_env_("FLASH_MLA_DISABLE_ASM") is true: the in-source kernel
//     flash_fwd_splitkv_mla_kernel_fp8_tp1 is launched directly.
//   * otherwise: a prebuilt code object is loaded once (cached in a function-local
//     static) from the path formed by the FLASH_MLA_ROOT_DIR environment variable
//     immediately followed by the code-object file name, and the kernel is looked
//     up by its mangled name.
//
// params: kernel arguments; passed to the kernels by value.
// stream: stream on which both kernels are enqueued.
//
// On any failure in the code-object path (missing environment variable, module
// load, symbol lookup or launch) a message is printed and the function returns
// without running the combine kernel.
template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla_fp8_tp1(Flash_fwd_mla_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
    const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    const static bool disable_asm = get_env_("FLASH_MLA_DISABLE_ASM");
    if (disable_asm) {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            auto kernel = &flash::flash_fwd_splitkv_mla_kernel_fp8_tp1<Kernel_traits, Is_causal, SharedStorage>;
            constexpr size_t smem_size = 65536;
            // CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
            kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
        });
    }
    else {
        static char* FLASH_MLA_ASM_DIR = std::getenv("FLASH_MLA_ROOT_DIR");
        // getenv() yields a runtime value, so it cannot be checked with static_assert;
        // a null result must also never reach the std::string constructor below.
        if (FLASH_MLA_ASM_DIR == nullptr) {
            printf("[flashmla] EXIT: environment variable FLASH_MLA_ROOT_DIR is not set\n");
            return;
        }
        constexpr size_t smem_size = 65536;
        // No path separator is inserted between the directory and the file name.
        std::string co_file = std::string(FLASH_MLA_ASM_DIR) +
            "flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co";
        hipError_t status = hipSuccess;
        static hipModule_t fwd_module_sample;
        static bool IS_FWD_MODULE_LOADED = false;
        if (IS_FWD_MODULE_LOADED == false)
        {
            status = hipModuleLoad(&fwd_module_sample, co_file.c_str());
            if (status not_eq hipSuccess) {
                printf("[flashmla] EXIT: failed to load module from %s\n", co_file.c_str());
                return;
            }
            IS_FWD_MODULE_LOADED = true;
        }
        // The whole params struct is handed to the kernel as one argument buffer.
        size_t params_size = sizeof(params);
        void* config[] = {
            HIP_LAUNCH_PARAM_BUFFER_POINTER,
            &params,
            HIP_LAUNCH_PARAM_BUFFER_SIZE,
            &params_size,
            HIP_LAUNCH_PARAM_END
        };
        dim3 grid(num_m_block, params.h, params.num_sm_parts);
        // Mangled names of the causal (Lb1) and non-causal (Lb0) instantiations.
        std::string kernel_name = params.is_causal ?
            "_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb1ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params":
            "_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb0ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params";
        hipFunction_t flash_mla_func;
        status = hipModuleGetFunction(&flash_mla_func, fwd_module_sample, kernel_name.c_str());
        // Do not launch through an unresolved function handle.
        if (status not_eq hipSuccess) {
            printf("[flashmla] EXIT: failed to get kernel function %s\n", kernel_name.c_str());
            return;
        }
        status = hipModuleLaunchKernel(
            flash_mla_func,
            grid.x, grid.y, grid.z,
            Kernel_traits::kNThreads, 1, 1,
            smem_size, // shared memory
            stream, // stream
            NULL,
            (void**)&config
        );
        if (status not_eq hipSuccess) {
            printf("[flashmla] EXIT: failed to launch kernel!\n");
            return;
        }
    }
    CHECK_CUDA_KERNEL_LAUNCH();
    // Second stage: one block per (batch, head, query row) merges the per-split partial results.
    dim3 grid_combine(params.b * params.h * params.seqlen_q);
    constexpr int kNThreads = 128;
    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
        auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
            typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>;
        combine_kernel<<<grid_combine, kNThreads, 0, stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}
// Launches the TP4 FP8 split-KV MLA forward kernel followed by the split-combine
// kernel, both on `stream`. The causal flag in `params` selects the kernel
// instantiation at run time.
template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla_fp8_tp4(Flash_fwd_mla_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);

    // Stage 1: attention over each scheduler partition.
    // Grid is (query row blocks, heads, scheduler partitions).
    const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        constexpr size_t dynamic_smem_bytes = 65536;
        const dim3 attn_grid(num_m_block, params.h, params.num_sm_parts);
        auto attn_kernel = &flash::flash_fwd_splitkv_mla_kernel_fp8_tp4<Kernel_traits, Is_causal, SharedStorage>;
        // CHECK_CUDA(cudaFuncSetAttribute(attn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_bytes));
        attn_kernel<<<attn_grid, Kernel_traits::kNThreads, dynamic_smem_bytes, stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();

    // Stage 2: merge the per-split partial results, one block per (batch, head, query row).
    constexpr int kNThreads = 128;
    const dim3 combine_grid(params.b * params.h * params.seqlen_q);
    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
        auto merge_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
            typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>;
        merge_kernel<<<combine_grid, kNThreads, 0, stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}
// Launches the base FP8 split-KV MLA forward pass and then the split-combine kernel.
//
// Two launch paths exist for the main kernel:
//   * Kernel_traits::IS_WITH_CAT is true, or get_env_("FLASH_MLA_ENABLE_ASM") is
//     false: the in-source kernel flash_fwd_splitkv_mla_kernel_fp8 is launched.
//   * otherwise: a prebuilt code object is loaded once (cached in a function-local
//     static) from the path formed by the FLASH_MLA_ROOT_DIR environment variable
//     immediately followed by the code-object file name, and the kernel is looked
//     up by its mangled name.
//
// params: kernel arguments; passed to the kernels by value.
// stream: stream on which both kernels are enqueued.
//
// On any failure in the code-object path (missing environment variable, module
// load, symbol lookup or launch) a message is printed and the function returns
// without running the combine kernel, matching run_flash_splitkv_fwd_mla_fp8_tp1.
template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla_fp8(Flash_fwd_mla_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
    const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    const static bool enable_asm = get_env_("FLASH_MLA_ENABLE_ASM");
    if (Kernel_traits::IS_WITH_CAT || !enable_asm) {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            auto kernel = &flash::flash_fwd_splitkv_mla_kernel_fp8<Kernel_traits, Is_causal, SharedStorage>;
            constexpr size_t smem_size = 32768;
            // CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
            kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
        });
    }
    else {
        static char* FLASH_MLA_ASM_DIR = std::getenv("FLASH_MLA_ROOT_DIR");
        // A null getenv() result must never reach the std::string constructor below.
        if (FLASH_MLA_ASM_DIR == nullptr) {
            printf("[flashmla] EXIT: environment variable FLASH_MLA_ROOT_DIR is not set\n");
            return;
        }
        constexpr size_t smem_size = 32768;
        // No path separator is inserted between the directory and the file name.
        std::string co_file = std::string(FLASH_MLA_ASM_DIR) +
            "flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co";
        hipError_t status = hipSuccess;
        static hipModule_t fwd_module_sample;
        static bool IS_FWD_MODULE_LOADED = false;
        if (IS_FWD_MODULE_LOADED == false)
        {
            status = hipModuleLoad(&fwd_module_sample, co_file.c_str());
            if (status not_eq hipSuccess) {
                printf("[flashmla] EXIT: failed to load module from %s\n", co_file.c_str());
                // Leave the flag unset so a failed load is never treated as a usable module.
                return;
            }
            IS_FWD_MODULE_LOADED = true;
        }
        // The whole params struct is handed to the kernel as one argument buffer.
        size_t params_size = sizeof(params);
        void* config[] = {
            HIP_LAUNCH_PARAM_BUFFER_POINTER,
            &params,
            HIP_LAUNCH_PARAM_BUFFER_SIZE,
            &params_size,
            HIP_LAUNCH_PARAM_END
        };
        dim3 grid(num_m_block, params.h, params.num_sm_parts);
        // Mangled names of the causal (Lb1) and non-causal (Lb0) instantiations.
        std::string kernel_name = params.is_causal ?
            "_ZN5flash32flash_fwd_splitkv_mla_kernel_fp8I34Flash_fwd_kernel_traits_mla_qkvfp8ILi576ELi16ELi64ELi4EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb1ENS_22SharedStorageMLAFloat8IS5_EEEEv20Flash_fwd_mla_params":
            "_ZN5flash32flash_fwd_splitkv_mla_kernel_fp8I34Flash_fwd_kernel_traits_mla_qkvfp8ILi576ELi16ELi64ELi4EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb0ENS_22SharedStorageMLAFloat8IS5_EEEEv20Flash_fwd_mla_params";
        hipFunction_t flash_mla_func;
        status = hipModuleGetFunction(&flash_mla_func, fwd_module_sample, kernel_name.c_str());
        // Do not launch through an unresolved function handle.
        if (status not_eq hipSuccess) {
            printf("[flashmla] EXIT: failed to get kernel function %s\n", kernel_name.c_str());
            return;
        }
        status = hipModuleLaunchKernel(
            flash_mla_func,
            grid.x, grid.y, grid.z,
            Kernel_traits::kNThreads, 1, 1,
            smem_size, // shared memory
            stream, // stream
            NULL,
            (void**)&config
        );
        if (status not_eq hipSuccess) {
            printf("[flashmla] EXIT: failed to launch kernel!\n");
            return;
        }
    }
    CHECK_CUDA_KERNEL_LAUNCH();
    // Second stage: one block per (batch, head, query row) merges the per-split partial results.
    dim3 grid_combine(params.b * params.h * params.seqlen_q);
    constexpr int kNThreads = 128;
    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
        auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
            typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits, kNThreads>;
        combine_kernel<<<grid_combine, kNThreads, 0, stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}
// Chooses the kernel-traits configuration for the FP8 split-KV MLA forward pass
// and forwards to the matching launcher.
//
// T:           query element type (compared against cutlass::float_e4m3_t).
// To:          forwarded as the output element type of the traits.
// Headdim:     must be 576.
// is_with_cat: selects the traits whose last bool argument is `true`.
//
// Dispatch when is_with_cat is false and T is fp8:
//   ngroups >= 64      -> TP1 traits
//   16 < ngroups < 64  -> TP4 traits
//   ngroups <= 16      -> base traits
// Note: when is_with_cat is false and T is not fp8, no kernel is launched.
template<typename T,typename To, int Headdim>
void run_mha_fwd_splitkv_mla_fp8(Flash_fwd_mla_params &params, cudaStream_t stream, bool is_with_cat) {
    static_assert(Headdim == 576);
    FLASH_ASSERT(params.d_v == 512);
    FLASH_ASSERT(params.k_ptr == params.v_ptr);  // Shared_KV

    constexpr bool kQueryIsFp8 = std::is_same_v<T, cutlass::float_e4m3_t>;

    if (!is_with_cat) {
        if constexpr (kQueryIsFp8) {
            const bool wants_tp1 = params.ngroups >= 64;
            const bool wants_tp4 = !wants_tp1 && params.ngroups > 16;
            if (wants_tp1) {
                using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8_TP1<576, 64, 64, 8, T, To, 512>;
                run_flash_splitkv_fwd_mla_fp8_tp1<Kernel_traits, flash::SharedStorageMLAFloat8_TP1<Kernel_traits>>(params, stream);
            } else if (wants_tp4) {
                using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8_TP4<576, 32, 64, 4, T, To, 512>;
                run_flash_splitkv_fwd_mla_fp8_tp4<Kernel_traits, flash::SharedStorageMLAFloat8_TP4<Kernel_traits>>(params, stream);
            } else {
                using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8<576, 16, 64, 4, T, To, 512, false>;
                run_flash_splitkv_fwd_mla_fp8<Kernel_traits, flash::SharedStorageMLAFloat8<Kernel_traits>>(params, stream);
            }
        }
        return;
    }

    // "With cat" path: always the base launcher, traits differ only by the query type.
    if constexpr (kQueryIsFp8) {
        using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8<576, 16, 64, 4, T, To, 512, true>;
        run_flash_splitkv_fwd_mla_fp8<Kernel_traits, flash::SharedStorageMLAFloat8<Kernel_traits>>(params, stream);
    } else {
        // q is bf16
        using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8<576, 16, 64, 4, cutlass::float_e4m3_t, To, 512, true, T>;
        run_flash_splitkv_fwd_mla_fp8<Kernel_traits, flash::SharedStorageMLAFloat8<Kernel_traits>>(params, stream);
    }
}
#include "flash_fwd_mla_kernel.h"
// Capacity of the statically sized shared-memory tables used by get_mla_metadata_kernel.
static constexpr int MaxBatchSize = 4096;

// Builds the tile-scheduler metadata and the cumulative per-batch split counts.
//
// Launched with 64 threads. All threads cooperate to compute the number of K
// blocks per batch entry and their total; thread 0 then walks the batch in order
// and hands each of the `num_sm_parts` partitions a payload of blocks. For
// partition i it writes five ints at
// tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize:
//   [0] first batch index, [1] first sequence offset,
//   [2] last batch index,  [3] last sequence offset (exclusive end),
//   [4] split index of the first batch entry.
// num_splits_ptr[0..batch_size] receives the running total of splits.
__global__ void __launch_bounds__(64, 1)
get_mla_metadata_kernel(const Mla_metadata_params params) {
    int *seqlens_k_ptr = params.seqlens_k_ptr;
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
    int *num_splits_ptr = params.num_splits_ptr;
    int batch_size = params.batch_size;
    int block_size_n = params.block_size_n;
    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
    int num_sm_parts = params.num_sm_parts;

    __shared__ int num_blocks_shared[MaxBatchSize];
    // Indexed up to batch_size inclusive, so batch_size must stay below MaxBatchSize.
    __shared__ int num_splits_shared[MaxBatchSize];

    // Each thread accumulates the block count (plus fixed overhead) of its batch entries.
    int total_num_blocks = 0;
    for (int i = threadIdx.x; i < batch_size; i += 64) {
        int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
        num_blocks_shared[i] = num_blocks;
    }
    // Butterfly reduction across the 64 threads; every thread ends with the full sum.
    for (int offset = 32; offset >= 1; offset /= 2) {
        total_num_blocks += __shfl_xor(total_num_blocks, offset, 64);
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
        int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
        num_splits_shared[0] = 0;
        for (int i = 0; i < num_sm_parts; ++i) {
            int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
            tile_scheduler_metadata0[0] = now_idx;
            tile_scheduler_metadata0[1] = now_block * block_size_n;
            tile_scheduler_metadata1 = now_n_split_idx;
            int remain_payload = payload;
            while (now_idx < batch_size) {
                int num_blocks = num_blocks_shared[now_idx];
                int now_remain_blocks = num_blocks - now_block;
                if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
                    // The rest of this batch entry fits: close it and move to the next one.
                    cum_num_splits += now_n_split_idx + 1;
                    num_splits_shared[now_idx + 1] = cum_num_splits;
                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
                    ++now_idx;
                    now_block = 0;
                    now_n_split_idx = 0;
                } else {
                    // Only part of the entry fits: take what the payload allows and stop.
                    if (remain_payload - fixed_overhead_num_blocks > 0) {
                        now_block += remain_payload - fixed_overhead_num_blocks;
                        ++now_n_split_idx;
                        remain_payload = 0;
                    }
                    break;
                }
            }
            const bool ends_inside_entry = now_block > 0;
            tile_scheduler_metadata0[2] = ends_inside_entry ? now_idx : now_idx - 1;
            // When no entry has been consumed yet (e.g. batch_size == 0), now_idx is 0
            // and there is no previous entry to read; avoid touching seqlens_k_ptr[-1].
            int end_seqlen = 0;
            if (ends_inside_entry) {
                end_seqlen = now_block * block_size_n;
            } else if (now_idx > 0) {
                end_seqlen = seqlens_k_ptr[now_idx - 1];
            }
            tile_scheduler_metadata0[3] = end_seqlen;
            *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
            tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
        }
        //FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
    }
    __syncthreads();

    // Publish the cumulative split counts (batch_size + 1 values).
    for (int i = threadIdx.x; i <= batch_size; i += 64) {
        num_splits_ptr[i] = num_splits_shared[i];
    }
}
// Host launcher for get_mla_metadata_kernel: checks the batch fits the kernel's static
// shared-memory arrays, then enqueues one 64-thread block on `stream`.
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.batch_size < MaxBatchSize);

    const dim3 grid(1);
    const dim3 block(64);
    get_mla_metadata_kernel<<<grid, block, 0, stream>>>(params);
    CHECK_CUDA_KERNEL_LAUNCH();
}
// Builds the split-KV tile-scheduler metadata for MLA decoding, working in block indices.
//
// Written for a single block of 64 threads: the per-batch loops stride by 64 and the shuffle
// reduction spans 64 lanes. Scratch arrays live in dynamic shared memory and need
// (5 * batch_size + 1) ints.
//
// When params.topk == -1 each sequence length is read from seqlens_k_ptr; otherwise every
// sequence is treated as having length params.topk.
//
// Writes the first five ints of one TileSchedulerMetaDataSize-int record per partition
// (num_sm_parts records) and batch_size + 1 running split totals to num_splits_ptr.
__global__ void __launch_bounds__(64, 1)
get_mla_decoding_metadata_kernel(const GetDecodingMetadataParams params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
int batch_size = params.batch_size;
int block_size_n = params.block_size_n;
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
// Carve the dynamic shared-memory buffer into five consecutive int arrays.
extern __shared__ int shared_mem[];
int* num_blocks_shared = shared_mem; // [batch_size]
int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size]
int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size]
int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size]
// Phase 1: every thread sizes its share of the sequences and accumulates a partial total.
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 64) {
int cur_s_k = params.topk == -1 ? __ldg(seqlens_k_ptr + i) : params.topk;
seqlens_k_shared[i] = cur_s_k;
// first_token_idx is always 0 here, so cur_first_block_idx is always 0 as well.
int first_token_idx = 0;
int last_token_idx = max(cur_s_k-1, 0);
int cur_first_block_idx = first_token_idx / block_size_n;
int cur_last_block_idx = last_token_idx / block_size_n;
// NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx]
// NOTE Before clamping, first_token_idx <= last_token_idx always holds, so after clamping, first_token_idx <= last_token_idx still holds.
// NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel.
int num_blocks = cur_last_block_idx - cur_first_block_idx + 1;
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
num_blocks_shared[i] = num_blocks;
first_block_idx_shared[i] = cur_first_block_idx;
last_block_idx_shared[i] = cur_last_block_idx;
}
// Butterfly reduction over the 64 lanes so every thread holds the grand total.
for (int offset = 32; offset >= 1; offset /= 2) {
// total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
total_num_blocks += __shfl_xor(total_num_blocks, offset, 64);
}
__syncthreads();
// Phase 2: a single thread walks the sequences in order and greedily fills each partition up
// to `payload` blocks, splitting a sequence across partitions when it does not fit.
if (threadIdx.x == 0) {
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
// now_idx: sequence being handed out; now_block: blocks of it already handed out;
// now_n_split_idx: earlier partitions that took a piece of it; cum_num_splits: running total.
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) {
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
tile_scheduler_metadata0[0] = now_idx;
// NOTE(review): once every sequence is consumed now_idx == batch_size, so this reads one
// element past first_block_idx_shared (i.e. last_block_idx_shared[0]) — confirm consumers
// ignore the begin offset of such idle partitions.
tile_scheduler_metadata0[1] = now_block + first_block_idx_shared[now_idx];
tile_scheduler_metadata1 = now_n_split_idx;
int remain_payload = payload;
while (now_idx < batch_size) {
int num_blocks = num_blocks_shared[now_idx];
int now_remain_blocks = num_blocks - now_block;
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
// The rest of this sequence fits: close it and move to the next one.
cum_num_splits += now_n_split_idx + 1;
num_splits_shared[now_idx + 1] = cum_num_splits;
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
++now_idx;
now_block = 0;
now_n_split_idx = 0;
} else {
// Only part of it fits: take what is left after paying the fixed overhead.
if (remain_payload - fixed_overhead_num_blocks > 0) {
now_block += remain_payload - fixed_overhead_num_blocks;
++now_n_split_idx;
remain_payload = 0;
}
break;
}
}
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
// On a sequence boundary the end block is one past the previous sequence's last block,
// or 0 for an empty sequence (this is the correction announced in phase 1).
// NOTE(review): with now_idx == 0 and now_block == 0 the [now_idx-1] reads index -1 of the
// shared sub-arrays (still inside shared_mem) — confirm this cannot matter (e.g. batch_size == 0).
tile_scheduler_metadata0[3] = now_block > 0 ? now_block + first_block_idx_shared[now_idx] : (seqlens_k_shared[now_idx-1] == 0 ? 0 : last_block_idx_shared[now_idx-1] + 1);
// NOTE(review): the int4 store assumes the metadata buffer is 16-byte aligned — confirm.
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
}
// Expected post-condition (device assert left disabled):
// FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}
__syncthreads();
// Phase 3: publish the batch_size + 1 running split totals to global memory.
for (int i = threadIdx.x; i <= batch_size; i += 64) {
num_splits_ptr[i] = num_splits_shared[i];
}
}
// Host launcher for get_mla_decoding_metadata_kernel: one 64-thread block on `stream`.
// The kernel carves five int arrays out of dynamic shared memory (four of length
// batch_size, one of length batch_size + 1), hence 5 * batch_size + 1 ints.
// NOTE(review): no dynamic shared-memory opt-in is requested here, so very large batches
// may exceed the device's default limit — confirm the supported batch range.
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream) {
    const int num_shared_ints = params.batch_size*5+1;
    int smem_size = sizeof(int) * num_shared_ints;

    get_mla_decoding_metadata_kernel<<<1, 64, smem_size, stream>>>(params);
    CHECK_CUDA_KERNEL_LAUNCH();
}
\ No newline at end of file
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
// Argument bundle taken by reference by the split-KV MLA forward launchers declared in this
// header (run_mha_fwd_splitkv_mla and run_mha_fwd_splitkv_mla_fp8).
// NOTE(review): the per-field notes below are inferred from field names only — confirm
// against the kernels before relying on them.
struct Flash_fwd_mla_params {
using index_t = int64_t; // integer type used by every stride field
int b, seqlen_q, d, d_v; // presumably batch size, query length, QK head dim, V head dim — TODO confirm
int h, h_h_k_ratio, ngroups; // presumably head count, Q-heads-per-KV-head ratio, group count — TODO confirm
bool is_causal;
float scale_softmax, scale_softmax_log2;
int *__restrict__ cu_seqlens_k;
// Type-erased tensor base pointers; the q_nope / q_pe pair presumably serves the split-query path — TODO confirm.
void *__restrict__ q_ptr;
void *__restrict__ q_nope_ptr;
void *__restrict__ q_pe_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr;
// Batch strides.
index_t q_batch_stride;
index_t q_nope_batch_stride;
index_t q_pe_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t o_batch_stride;
// Row strides.
index_t q_row_stride;
index_t q_nope_row_stride;
index_t q_pe_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t o_row_stride;
// Head strides.
index_t q_head_stride;
index_t q_nope_head_stride;
index_t q_pe_head_stride;
index_t k_head_stride;
index_t v_head_stride;
index_t o_head_stride;
// Paged KV-cache lookup (presumably) — TODO confirm.
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// Tile-scheduler data; the field names match the outputs of the metadata kernels
// (tile_scheduler_metadata_ptr / num_splits_ptr) — presumably produced by them, TODO confirm.
int *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
// Accumulation buffers, presumably for partial split results — TODO confirm.
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
// Quantisation scale factors (presumably for the FP8 paths) — TODO confirm.
void *__restrict__ k_scale_ptr;
float * __restrict__ descale_q_ptr ;
float * __restrict__ descale_k_ptr ;
};
// Argument bundle taken by const reference by run_mha_fwd_sparse_prefill (declared in this header).
// Tensor shapes are given in the trailing comments on each pointer.
struct SparsePrefillParams {
// Dimensions used in the shape comments below.
int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk;
// Softmax scale, plus a second value that is presumably sm_scale converted for base-2 exponentials — TODO confirm.
float sm_scale, sm_scale_div_log2;
// Input tensors
void* __restrict__ q; // [s_q, h_q, d_qk]
void* __restrict__ kv; // [s_kv, h_kv, d_qk]
void* __restrict__ indices; // [s_q, h_kv, topk]
// Strides of the two leading dimensions of each input tensor (units presumably elements — TODO confirm).
int stride_q_s_q;
int stride_q_h_q;
int stride_kv_s_kv;
int stride_kv_h_kv;
int stride_indices_s_q;
int stride_indices_h_kv;
// Output tensors
void* __restrict__ out; // [s_q, h_q, d_v]
void* __restrict__ max_logits; // [s_q, h_q]
void* __restrict__ lse; // [s_q, h_q]
// cudaStream_t stream;
};
// Number of ints in one tile-scheduler metadata record (one record per SM partition).
// Record layout; only the first five ints are assigned by the metadata kernels:
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
static constexpr int TileSchedulerMetaDataSize = 8;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Host-side launcher declarations; the definitions live elsewhere.
// NOTE(review): these take hipStream_t while other declarations in this header take
// cudaStream_t — presumably the same type after hipification; confirm.

// Split-KV MLA forward for element type T and head dimension Headdim.
// kv_cache_dtype and is_q_nope_pe presumably select the KV-cache format and the split
// q_nope/q_pe input path — TODO confirm.
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, const std::string& kv_cache_dtype, hipStream_t stream, bool is_q_nope_pe = false);
// FP8 variant with input element type T and a second element type To (presumably the output type — TODO confirm).
template<typename T, typename To, int Headdim>
void run_mha_fwd_splitkv_mla_fp8(Flash_fwd_mla_params &params,hipStream_t stream, bool is_with_cat);
// Sparse prefill forward driven by a SparsePrefillParams bundle.
template<typename T, int Headdim>
void run_mha_fwd_sparse_prefill(const SparsePrefillParams &params, hipStream_t stream);
// Arguments of get_mla_metadata_kernel (passed to the kernel by value).
struct Mla_metadata_params {
int *__restrict__ seqlens_k_ptr; // in: KV length per sequence; the kernel reads batch_size entries
int *__restrict__ tile_scheduler_metadata_ptr; // out: num_sm_parts records of TileSchedulerMetaDataSize ints
int *__restrict__ num_splits_ptr; // out: batch_size + 1 running split totals
int batch_size; // number of sequences; the launcher asserts it is below MaxBatchSize
int block_size_n; // KV tokens per block (divisor for the per-sequence block count)
int fixed_overhead_num_blocks; // per-sequence cost, in blocks, added when balancing partitions
int num_sm_parts; // number of partitions to schedule
};
// Enqueues get_mla_metadata_kernel on `stream` after checking params.batch_size against MaxBatchSize.
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);
// Storage format tag for the KV cache. The numeric values are spelled out explicitly.
// NOTE(review): member meanings are inferred from their names (automatic selection, FP8
// E4M3, FP8 E5M2, INT8) — confirm against the code that consumes this enum.
enum class Fp8KVCacheDataType : int {
    kAuto    = 0,
    kFp8E4M3 = 1,
    kFp8E5M2 = 2,
    kInt8    = 3
};
// Reads an environment variable as a boolean switch.
// Returns false when the variable is unset or is exactly "0"; any other value,
// including the empty string, counts as true.
static inline bool get_env_(const char *env_var) {
    const char *raw = std::getenv(env_var);
    if (raw == nullptr) {
        return false;
    }
    return strcmp(raw, "0") != 0;
}
// Arguments of get_mla_decoding_metadata_kernel (passed to the kernel by value).
struct GetDecodingMetadataParams {
int *__restrict__ seqlens_k_ptr; // in: KV length per sequence; only read when topk == -1
int *__restrict__ tile_scheduler_metadata_ptr; // out: num_sm_parts records of TileSchedulerMetaDataSize ints
int *__restrict__ num_splits_ptr; // out: batch_size + 1 running split totals
int batch_size; // number of sequences; also sizes the kernel's dynamic shared memory
int block_size_n; // KV tokens per block
int fixed_overhead_num_blocks; // per-sequence cost, in blocks, added when balancing partitions
int num_sm_parts; // number of partitions to schedule
int topk; // -1: use seqlens_k_ptr; otherwise every sequence is treated as having this length
};
// Enqueues get_mla_decoding_metadata_kernel on `stream` with (5 * batch_size + 1) ints of dynamic shared memory.
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream);
// Argument bundle taken by const reference by run_flash_splitkv_sparse_mla_kernel (declared in this header).
// NOTE(review): notes on fields without an original comment are inferred from names — confirm.
struct DecodingParams {
using index_t = int64_t; // integer type used by every stride field
int b; // batch size
int s_q; // presumably the query sequence length — TODO confirm
int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q
int d, d_v; // K/V dimension
int h_q, h_k; // The number of Q/K heads
int num_blocks; // Number of blocks in total
int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal;
float scale_softmax, scale_softmax_log2;
int topk;
// Type-erased tensor base pointers.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr;
int *__restrict__ indices_ptr;
// Batch / row / head strides for q, k and o, followed by the strides of the indices tensor.
index_t q_batch_stride;
index_t k_batch_stride;
index_t o_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t o_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t o_head_stride;
index_t indices_batch_stride;
index_t indices_row_stride;
// Paged KV-cache lookup (presumably) — TODO confirm.
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// Sequence lengths and tile-scheduler data; the field names match the inputs and outputs
// of the metadata kernels — presumably shared with them, TODO confirm.
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
int total_num_splits;
// Accumulation buffers, presumably for partial split results — TODO confirm.
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
};
// Host-side launcher declaration (defined elsewhere) for the sparse split-KV MLA decoding
// kernel, parameterised on element type T and head dimension Headdim.
template<typename T, int Headdim>
void run_flash_splitkv_sparse_mla_kernel(const DecodingParams &params, cudaStream_t stream);
\ No newline at end of file
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Per-thread row reduction: folds every column of each row of `tensor` into `summary` using `op`.
// With zero_init the row result starts from the row's first element; otherwise the value
// already stored in `summary` takes part in the reduction.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int row = 0; row < size<0>(tensor); row++) {
        // Seed the row result with column 0.
        if constexpr (zero_init) {
            summary(row) = tensor(row, 0);
        } else {
            summary(row) = op(summary(row), tensor(row, 0));
        }
        // Fold in the remaining columns.
        #pragma unroll
        for (int col = 1; col < size<1>(tensor); col++) {
            summary(row) = op(summary(row), tensor(row, col));
        }
    }
}
// Element-wise cross-thread all-reduce: each entry of `src` is combined through
// Allreduce<64>::run with `op` and stored in the matching entry of `dst`.
// `dst` and `src` may be the same tensor.
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int idx = 0; idx < size(dst); ++idx) {
        dst(idx) = Allreduce<64>::run(src(idx), op);
    }
}
// Block-wide all-reduce of dst(0) through shared memory.
// Threads are grouped in runs of 64; within a group the first 16 threads each publish one
// value. Values from different groups that share the same slot (thread id modulo 16) are then
// combined with `op`, and every thread reads back the result for its slot.
// Exactly four values per slot are combined, so this presumably expects four 64-thread
// groups per block — TODO confirm. Only dst(0) is reduced.
// Contains two __syncthreads(), so every thread of the block must call it.
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void warp_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &smem_reduce, Operator &op) {
    const int tid = threadIdx.x;
    const int lane = tid % 64;    // position inside the 64-thread group
    const int group = tid / 64;   // index of the 64-thread group
    const int slot = tid % 16;    // reduction slot this thread belongs to

    // Stage 1: publish partials, laid out so the values of one slot are adjacent
    // (index = slot * 4 + group). The alternative layout (slot + group * 16) would make the
    // writes contiguous and the reads strided instead; per the original author's note the
    // performance difference is small.
    if (lane < 16) {
        smem_reduce(slot * 4 + group) = dst(0);
    }
    __syncthreads();

    // Stage 2: the first 16 threads combine the four partials of their slot and store the
    // result after the staging area (offset 64).
    if (tid < 16) {
        const int base = slot * 4;
        smem_reduce(slot + 64) = op(op(smem_reduce(base), smem_reduce(base + 1)), op(smem_reduce(base + 2), smem_reduce(base + 3)));
    }
    __syncthreads();

    // Stage 3: broadcast — every thread picks up the reduced value of its slot.
    dst(0) = smem_reduce(slot + 64);
}
// Block-wide all-reduce of dst(0) through shared memory, pairing 64-thread groups g and g + 4
// (0-4, 1-5, 2-6, 3-7). Each slot (64 slots: 16 per group modulo 4) receives two partials,
// which are combined with `op`; the result is stored at offset 128 and read back by every thread.
// Presumably expects eight 64-thread groups per block — TODO confirm.
// Contains two __syncthreads(), so every thread of the block must call it.
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void warp_allreduce_tp1(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &smem_reduce, Operator &op) {
    const int tid = threadIdx.x;
    const int group = tid / 64;
    // Only the first 16 threads of each 64-thread group publish a value.
    const bool publishes = ((tid % 64) / 16) == 0;
    const int slot = tid % 16 + (group % 4) * 16;

    // Stage 1: the two paired groups write next to each other (index = slot * 2 + group / 4).
    if (publishes) {
        smem_reduce[slot * 2 + (group / 4)] = dst[0];
    }
    __syncthreads();

    // Stage 2: the lower four groups combine each pair.
    if (publishes && group < 4) {
        smem_reduce[128 + slot] = op(smem_reduce[slot * 2], smem_reduce[slot * 2 + 1]);
    }
    __syncthreads();

    // Stage 3: broadcast the reduced value of this thread's slot.
    dst(0) = smem_reduce(128 + slot);
}
// Block-wide all-reduce of dst(0) through shared memory, pairing 64-thread groups g and g + 2
// (0-2, 1-3). Each slot (32 slots: 16 per group modulo 2) receives two partials, which are
// combined with `op`; the result is stored at offset 64 and read back by every thread.
// Presumably expects four 64-thread groups per block — TODO confirm.
// Contains two __syncthreads(), so every thread of the block must call it.
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void warp_allreduce_tp4(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &smem_reduce, Operator &op) {
    const int tid = threadIdx.x;
    const int group = tid / 64;
    // Only the first 16 threads of each 64-thread group publish a value.
    const bool publishes = ((tid % 64) / 16) == 0;
    const int slot = tid % 16 + (group % 2) * 16;

    // Stage 1: the two paired groups write next to each other (index = slot * 2 + group / 2).
    if (publishes) {
        smem_reduce[slot * 2 + (group / 2)] = dst[0];
    }
    __syncthreads();

    // Stage 2: the lower two groups combine each pair.
    if (publishes && group < 2) {
        smem_reduce[slot + 64] = op(smem_reduce[slot * 2], smem_reduce[slot * 2 + 1]);
    }
    __syncthreads();

    // Stage 3: broadcast the reduced value of this thread's slot.
    dst(0) = smem_reduce(slot + 64);
}
// Row reduction in two steps: first within the thread (thread_reduce_), then across
// threads in place on `summary` (quad_allreduce_).
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    // Step 1: per-thread partial for every row.
    flash::thread_reduce_<zero_init>(tensor, summary, op);
    // Step 2: combine the partials across threads; source and destination are the same tensor.
    flash::quad_allreduce_(summary, summary, op);
}
// Row-wise maximum of `tensor` written to `max`, reduced within the thread and then across
// threads via reduce_. With zero_init == false the previous content of `max` participates.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
    MaxOp<float> take_max;
    flash::reduce_<zero_init>(tensor, max, take_max);
}
// Row-wise sum of `tensor` accumulated into `sum`, within the calling thread only.
// No cross-thread reduction happens here; the Softmax normalisation methods perform it
// later, when the row sum is actually needed.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
    SumOp<float> add;
    flash::thread_reduce_<zero_init>(tensor, sum, add);
}
// Replaces every element x of `tensor` by exp2(x * scale - m), where m is the row maximum
// multiplied by `scale` (Scale_max == true) or by log2(e) (Scale_max == false).
// `max` holds one maximum per row of `tensor`.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row) {
        // A row maximum of -inf means every element of the row is -inf (e.g. fully masked).
        // Subtracting it would yield (-inf) - (-inf) = NaN, so 0 is used instead.
        // M_LOG2E is cast to float so that the multiplication is not carried out in fp64.
        const float max_scaled = (max(row) == -INFINITY)
            ? 0.f
            : max(row) * (Scale_max ? scale : float(M_LOG2E));
        #pragma unroll
        for (int col = 0; col < size<1>(tensor); ++col) {
            // exp(x - max) is evaluated as exp2(x * log2(e) - max * log2(e)) so the multiply and
            // subtract can be fused into one instruction. The portable spelling is
            // exp2f(tensor * scale - max_scaled); the AMD GPU builtin is used here.
            tensor(row, col) = __builtin_amdgcn_exp2f(tensor(row, col) * scale - max_scaled);
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
__forceinline__ __device__ Softmax() {};
// One online-softmax step for a new tile of attention scores.
//   acc_s  : score accumulator; exponentiated in place.
//   acc_o  : output accumulator; rescaled when the running row maximum changes (not touched when Is_first).
//   sRow_max_reduce_buffer : shared-memory scratch for the block-wide max reduction.
//   softmax_scale_log2     : factor applied to the scores inside the base-2 exponential.
// Template flags:
//   Is_first  : first tile — row_max / row_sum are initialised instead of updated.
//   Check_inf : treat a row maximum of -inf as 0 when computing the rescale factor.
//   is_tp1    : use warp_allreduce_tp1 instead of warp_allreduce_ for the block-wide reduction.
// row_sum is only accumulated per thread here; normalize_softmax_lse reduces it across threads.
// NOTE(review): the block-wide reductions only exchange row_max(0), so kNRows is presumably 1 on this path — confirm.
template<bool Is_first, bool Check_inf=false, bool is_tp1=false, typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
// First tile: row maximum from scratch, block-wide reduce, exponentiate, start the row sum.
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
}
else
{
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
// Later tiles: keep the previous maximum so earlier contributions can be rescaled.
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
}
else
{
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
}
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
// Rescale factor exp2((old_max - new_max) * scale); the portable exp2f form is kept disabled.
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
// if (blockIdx.x == 0 && threadIdx.x == 0)
// {
// printf("threadIdx.x %.2f, scores_scale = %.4f\n",row_sum(mi), scores_scale );
// }
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
// if (block0())
// {
// printf("normalize_softmax_lse %.4f\n", row_sum(0));
// }
};
// Online-softmax step for the FP8 path, where the output accumulator is held in eight v4f
// registers (c0_0 ... c3_1) instead of a tensor.
//   acc_s  : score accumulator; exponentiated in place.
//   sRow_max_reduce_buffer : shared-memory scratch for the block-wide max reduction.
//   softmax_scale_log2     : factor applied to the scores inside the base-2 exponential.
//   c0_0 ... c3_1          : output accumulators; every component is multiplied by the rescale
//                            factor when the row maximum changes (not touched when Is_first).
// Requires kNRows == 1 on the update path (static_assert below).
// row_sum is only accumulated per thread here; the cross-thread reduction happens at normalisation.
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2,
v4f& c0_0, v4f& c0_1, v4f& c1_0, v4f& c1_1, v4f& c2_0, v4f& c2_1, v4f& c3_0, v4f& c3_1
) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
// First tile: row maximum from scratch, block-wide reduce, exponentiate, start the row sum.
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
// Later tiles: keep the previous maximum so earlier contributions can be rescaled.
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(1 == kNRows);
// Single row, so the generic loop over rows is replaced by mi = 0.
// #pragma unroll
// for (int mi = 0; mi < size(row_max); ++mi)
{
int mi = 0;
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
// Rescale factor exp2((old_max - new_max) * scale); the portable exp2f form is kept disabled.
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
row_sum(mi) *= scores_scale;
// Rescale every component of the register-resident output accumulators.
// #pragma unroll
// for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
c0_0.x *= scores_scale; c0_0.y *= scores_scale; c0_0.z *= scores_scale; c0_0.w *= scores_scale;
c0_1.x *= scores_scale; c0_1.y *= scores_scale; c0_1.z *= scores_scale; c0_1.w *= scores_scale;
c1_0.x *= scores_scale; c1_0.y *= scores_scale; c1_0.z *= scores_scale; c1_0.w *= scores_scale;
c1_1.x *= scores_scale; c1_1.y *= scores_scale; c1_1.z *= scores_scale; c1_1.w *= scores_scale;
c2_0.x *= scores_scale; c2_0.y *= scores_scale; c2_0.z *= scores_scale; c2_0.w *= scores_scale;
c2_1.x *= scores_scale; c2_1.y *= scores_scale; c2_1.z *= scores_scale; c2_1.w *= scores_scale;
c3_0.x *= scores_scale; c3_0.y *= scores_scale; c3_0.z *= scores_scale; c3_0.w *= scores_scale;
c3_1.x *= scores_scale; c3_1.y *= scores_scale; c3_1.z *= scores_scale; c3_1.w *= scores_scale;
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
// Finishes the softmax: reduces row_sum across threads, divides acc_o by it and returns
// the per-row log-sum-exp.
//   acc_o                  : output accumulator, normalised in place.
//   sRow_sum_reduce_buffer : shared-memory scratch for the block-wide sum reduction.
//   softmax_scale          : multiplies row_max in the log-sum-exp.
//   rp_dropout             : extra output factor, applied only when Is_dropout.
// Rows whose sum is 0 or NaN keep their accumulator (factor 1) and report an LSE of
// -inf (Split) or +inf (otherwise).
// is_tp1 selects warp_allreduce_tp1 instead of warp_allreduce_ for the block-wide reduction.
template<bool Is_dropout=false, bool Split=false, bool is_tp1 = false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale, float rp_dropout=1.0) {
    SumOp<float> add;
    // The per-thread sums were never combined during the main loop; do it now.
    quad_allreduce_(row_sum, row_sum, add);
    if constexpr (is_tp1) {
        flash::warp_allreduce_tp1(row_sum, sRow_sum_reduce_buffer, add);
    } else {
        flash::warp_allreduce_(row_sum, sRow_sum_reduce_buffer, add);
    }

    TensorT lse = make_fragment_like(row_sum);
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
    #pragma unroll
    for (int row = 0; row < size<0>(acc_o_rowcol); ++row) {
        float total = row_sum(row);
        // `total != total` is true only for NaN.
        const bool degenerate = (total == 0.f || total != total);
        float inv_total = degenerate ? 1.f : 1.f / total;
        if (degenerate) {
            lse(row) = Split ? -INFINITY : INFINITY;
        } else {
            lse(row) = row_max(row) * softmax_scale + __logf(total);
        }
        float factor = !Is_dropout ? inv_total : inv_total * rp_dropout;
        #pragma unroll
        for (int col = 0; col < size<1>(acc_o_rowcol); ++col) {
            acc_o_rowcol(row, col) *= factor;
        }
    }
    return lse;
}
// Online-softmax step used by the prefill path; same structure as softmax_rescale_o but
// always reduces with warp_allreduce_.
//   acc_s  : score accumulator; exponentiated in place.
//   acc_o  : output accumulator; rescaled when the running row maximum changes (not touched when Is_first).
//   sRow_max_reduce_buffer : shared-memory scratch for the block-wide max reduction.
//   softmax_scale_log2     : factor applied to the scores inside the base-2 exponential.
// NOTE(review): the Check_inf template parameter is not consulted — the -inf guard below is
// hard-wired on (`!true ? ... : ...`). Presumably intentional for this path; confirm.
// row_sum is only accumulated per thread here; normalize_softmax_lse_prefill reduces it across threads.
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_prefill(Tensor0 &acc_s, Tensor1 &acc_o, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
// First tile: row maximum from scratch, block-wide reduce, exponentiate, start the row sum.
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
// Later tiles: keep the previous maximum so earlier contributions can be rescaled.
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
// `!true` always selects the second operand: a row maximum of -inf is replaced by 0.
float scores_max_cur = !true
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
// Rescale factor exp2((old_max - new_max) * scale); the portable exp2f form is kept disabled.
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
// if (blockIdx.x == 0 && threadIdx.x == 0)
// {
// printf("threadIdx.x %.2f, scores_scale = %.4f\n",row_sum(mi), scores_scale );
// }
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
// if (blockIdx.x == 2)
// {
// printf("threadIdx.x %.2f \n",row_sum(mi) );
// }
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
// Prefill variant of the softmax finish: reduces row_sum across threads, divides acc_o by it
// and returns the per-row log-sum-exp.
// Unlike normalize_softmax_lse, the sum enters the LSE through a base-2 logarithm (__log2f).
// NOTE(review): presumably callers pass a softmax_scale that is also expressed for base 2 — confirm.
//   acc_o                  : output accumulator, normalised in place.
//   sRow_sum_reduce_buffer : shared-memory scratch for the block-wide sum reduction.
//   rp_dropout             : extra output factor, applied only when Is_dropout.
// Rows whose sum is 0 or NaN keep their accumulator (factor 1) and report an LSE of
// -inf (Split) or +inf (otherwise).
template<bool Is_dropout=false, bool Split=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_prefill(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale, float rp_dropout=1.0) {
    SumOp<float> add;
    // The per-thread sums were never combined during the main loop; do it now.
    quad_allreduce_(row_sum, row_sum, add);
    flash::warp_allreduce_(row_sum, sRow_sum_reduce_buffer, add);

    TensorT lse = make_fragment_like(row_sum);
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
    #pragma unroll
    for (int row = 0; row < size<0>(acc_o_rowcol); ++row) {
        float total = row_sum(row);
        // `total != total` is true only for NaN.
        const bool degenerate = (total == 0.f || total != total);
        float inv_total = degenerate ? 1.f : 1.f / total;
        if (degenerate) {
            lse(row) = Split ? -INFINITY : INFINITY;
        } else {
            lse(row) = row_max(row) * softmax_scale + __log2f(total);
        }
        float factor = !Is_dropout ? inv_total : inv_total * rp_dropout;
        #pragma unroll
        for (int col = 0; col < size<1>(acc_o_rowcol); ++col) {
            acc_o_rowcol(row, col) *= factor;
        }
    }
    return lse;
}
// FP8 variant of the softmax finish: reduces row_sum across threads, scales acc_o by
// descale_v / row_sum and returns the per-row log-sum-exp.
//   acc_o                  : output accumulator, normalised in place.
//   sRow_sum_reduce_buffer : shared-memory scratch for the block-wide sum reduction.
//   softmax_scale          : multiplies row_max in the log-sum-exp.
//   descale_v              : folded into the output factor (not applied to degenerate rows).
//   rp_dropout             : extra output factor, applied only when Is_dropout.
// Rows whose sum is 0 or NaN keep their accumulator (factor 1, descale_v not applied) and
// report an LSE of -inf (Split) or +inf (otherwise).
template<bool Is_dropout=false, bool Split=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8(Tensor0 &acc_o, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
    SumOp<float> add;
    // The per-thread sums were never combined during the main loop; do it now.
    quad_allreduce_(row_sum, row_sum, add);
    flash::warp_allreduce_(row_sum, sRow_sum_reduce_buffer, add);

    TensorT lse = make_fragment_like(row_sum);
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
    #pragma unroll
    for (int row = 0; row < size<0>(acc_o_rowcol); ++row) {
        float total = row_sum(row);
        // `total != total` is true only for NaN.
        const bool degenerate = (total == 0.f || total != total);
        float inv_total = degenerate ? 1.f : descale_v / total;
        if (degenerate) {
            lse(row) = Split ? -INFINITY : INFINITY;
        } else {
            lse(row) = row_max(row) * softmax_scale + __logf(total);
        }
        float factor = !Is_dropout ? inv_total : inv_total * rp_dropout;
        #pragma unroll
        for (int col = 0; col < size<1>(acc_o_rowcol); ++col) {
            acc_o_rowcol(row, col) *= factor;
        }
    }
    return lse;
}
// Online-softmax step for the FP8 path whose output accumulator is an array of v4f registers.
//   acc_s  : score accumulator; exponentiated in place.
//   sRow_max_reduce_buffer : shared-memory scratch for the block-wide max reduction.
//   softmax_scale_log2     : factor applied to the scores inside the base-2 exponential.
//   acco_f32               : output accumulators; the first 16 v4f entries are multiplied by the
//                            rescale factor when the row maximum changes (not touched when Is_first).
// Template flags:
//   Is_first  : first tile — row_max / row_sum are initialised instead of updated.
//   Check_inf : treat a row maximum of -inf as 0 when computing the rescale factor.
//   is_tp1    : use warp_allreduce_tp1 instead of warp_allreduce_ for the block-wide reduction.
// Requires kNRows == 1 on the update path (static_assert below).
// row_sum is only accumulated per thread here; the cross-thread reduction happens at normalisation.
template<bool Is_first, bool Check_inf=false, bool is_tp1=false, typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8_tp1(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2,
v4f *acco_f32
) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if constexpr (Is_first) {
// First tile: row maximum from scratch, block-wide reduce, exponentiate, start the row sum.
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
}
else
{
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
// Later tiles: keep the previous maximum so earlier contributions can be rescaled.
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
if constexpr (is_tp1)
{
flash::template warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
}
else
{
flash::template warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
}
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(1 == kNRows);
// Single row, so the generic loop over rows is replaced by mi = 0.
// #pragma unroll
// for (int mi = 0; mi < size(row_max); ++mi)
{
int mi = 0;
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
// Rescale factor exp2((old_max - new_max) * scale); the portable exp2f form is kept disabled.
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
row_sum(mi) *= scores_scale;
// Rescale the 16 register-resident output accumulators (this loop supersedes the
// per-register c0_0 ... c3_1 updates used by softmax_rescale_o_fp8).
for (int i = 0; i < 16; i++)
{
acco_f32[i].x *= scores_scale;
acco_f32[i].y *= scores_scale;
acco_f32[i].z *= scores_scale;
acco_f32[i].w *= scores_scale;
}
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
// Online-softmax step for the FP8 TP4 path; identical in structure to softmax_rescale_o_fp8_tp1
// but always reduces the row maximum with warp_allreduce_tp4.
//   acc_s  : score accumulator; exponentiated in place.
//   sRow_max_reduce_buffer : shared-memory scratch for the block-wide max reduction.
//   softmax_scale_log2     : factor applied to the scores inside the base-2 exponential.
//   acco_f32               : output accumulators; the first 16 v4f entries are multiplied by the
//                            rescale factor when the row maximum changes (not touched when Is_first).
// Template flags:
//   Is_first  : first tile — row_max / row_sum are initialised instead of updated.
//   Check_inf : treat a row maximum of -inf as 0 when computing the rescale factor.
// Requires kNRows == 1 on the update path (static_assert below).
// row_sum is only accumulated per thread here; the cross-thread reduction happens at normalisation.
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8_tp4(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer, float softmax_scale_log2,
v4f *acco_f32
) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if constexpr (Is_first) {
// First tile: row maximum from scratch, block-wide reduce, exponentiate, start the row sum.
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::template warp_allreduce_tp4(row_max, sRow_max_reduce_buffer, max_op);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
// Later tiles: keep the previous maximum so earlier contributions can be rescaled.
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
flash::template warp_allreduce_tp4(row_max, sRow_max_reduce_buffer, max_op);
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(1 == kNRows);
// Single row, so the generic loop over rows is replaced by mi = 0.
// #pragma unroll
// for (int mi = 0; mi < size(row_max); ++mi)
{
int mi = 0;
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
// Rescale factor exp2((old_max - new_max) * scale); the portable exp2f form is kept disabled.
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
row_sum(mi) *= scores_scale;
// Rescale the 16 register-resident output accumulators (this loop supersedes the
// per-register c0_0 ... c3_1 updates used by softmax_rescale_o_fp8).
for (int i = 0; i < 16; i++)
{
acco_f32[i].x *= scores_scale;
acco_f32[i].y *= scores_scale;
acco_f32[i].z *= scores_scale;
acco_f32[i].w *= scores_scale;
}
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
/// Finishes the softmax: reduces `row_sum`, scales the output accumulators and
/// returns the log-sum-exp.
///
/// `row_sum` is reduced with quad_allreduce_ and then with warp_allreduce_tp1
/// (is_tp1) or warp_allreduce_ (otherwise) through sRow_sum_reduce_buffer.
/// The 16 v4f accumulators in acco_f are multiplied by descale_v / sum
/// (additionally by rp_dropout when Is_dropout). If the sum is 0 or NaN the
/// accumulators are left unscaled (factor 1) and the LSE becomes -INFINITY when
/// Split, +INFINITY otherwise.
///
/// Returns a fragment shaped like `row_sum`; only element 0 is written.
template<bool Is_dropout=false, bool Split=false, bool is_tp1=false, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8_tp1(v4f *acco_f, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
    if constexpr (!is_tp1) {
        flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
    } else {
        flash::template warp_allreduce_tp1(row_sum, sRow_sum_reduce_buffer, sum_op);
    }

    TensorT lse = make_fragment_like(row_sum);

    const float total = row_sum(0);
    // `total != total` is the NaN test.
    const bool degenerate = (total == 0.f) || (total != total);

    float scale = 1.f;
    if (degenerate) {
        lse(0) = Split ? -INFINITY : INFINITY;
    } else {
        lse(0) = row_max(0) * softmax_scale + __logf(total);
        scale = descale_v / total;
    }
    if constexpr (Is_dropout) {
        scale *= rp_dropout;
    }

    for (int idx = 0; idx < 16; ++idx) {
        acco_f[idx].x *= scale;
        acco_f[idx].y *= scale;
        acco_f[idx].z *= scale;
        acco_f[idx].w *= scale;
    }
    return lse;
};
/// Finishes the softmax for the "tp4" variant: reduces `row_sum`, scales the
/// output accumulators and returns the log-sum-exp.
///
/// `row_sum` is reduced with quad_allreduce_ and then warp_allreduce_tp4 through
/// sRow_sum_reduce_buffer. The 16 v4f accumulators in acco_f are multiplied by
/// descale_v / sum (additionally by rp_dropout when Is_dropout). If the sum is
/// 0 or NaN the accumulators are left unscaled (factor 1) and the LSE becomes
/// -INFINITY when Split, +INFINITY otherwise.
///
/// The is_tp1 template parameter is accepted but not used by this variant.
/// Returns a fragment shaped like `row_sum`; only element 0 is written.
template<bool Is_dropout=false, bool Split=false, bool is_tp1=false, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8_tp4(v4f *acco_f, Tensor1& sRow_sum_reduce_buffer, float softmax_scale,float descale_v, float rp_dropout=1.0) {
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
    flash::template warp_allreduce_tp4(row_sum, sRow_sum_reduce_buffer, sum_op);

    TensorT lse = make_fragment_like(row_sum);

    const float total = row_sum(0);
    // `total != total` is the NaN test.
    const bool degenerate = (total == 0.f) || (total != total);

    float scale = 1.f;
    if (degenerate) {
        lse(0) = Split ? -INFINITY : INFINITY;
    } else {
        lse(0) = row_max(0) * softmax_scale + __logf(total);
        scale = descale_v / total;
    }
    if constexpr (Is_dropout) {
        scale *= rp_dropout;
    }

    for (int idx = 0; idx < 16; ++idx) {
        acco_f[idx].x *= scale;
        acco_f[idx].y *= scale;
        acco_f[idx].z *= scale;
        acco_f[idx].w *= scale;
    }
    return lse;
};
};
} // namespace flash
#pragma once
// Evaluates `call` (an expression yielding a cudaError_t) exactly once. If the
// result is not cudaSuccess, prints the source file, line and the
// cudaGetErrorString() text to stderr and terminates the process with exit(1).
// Wrapped in do { ... } while(0) so it behaves like a single statement.
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while(0)
// Runs CHECK_CUDA on cudaGetLastError(); intended to be placed right after a kernel launch.
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
// Host-side assertion that does not depend on NDEBUG: when `cond` evaluates to
// false, the stringified condition is written to stderr together with the file
// and line, and the process exits with status 1. `cond` is evaluated once.
#define FLASH_ASSERT(cond)                                                                    \
    do {                                                                                      \
        if (!(cond)) {                                                                        \
            fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond);     \
            exit(1);                                                                          \
        }                                                                                     \
    } while(0)
// Device-side counterpart of FLASH_ASSERT: when `cond` is false, prints the
// stringified condition with file and line via printf and then executes the
// S_ENDPGM instruction through inline asm (presumably ending the program for
// the current wavefront -- confirm against the ISA documentation).
#define FLASH_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
asm("S_ENDPGM;"); \
} \
} while(0)
// Turns a runtime boolean into a compile-time constant.
// Expands to an immediately-invoked lambda that declares
// `constexpr static bool CONST_NAME` as true or false depending on COND and
// then invokes the callable passed as the trailing argument, returning its
// result. The callable can use CONST_NAME as a constant expression
// (e.g. as a template argument).
#define BOOL_SWITCH(COND, CONST_NAME, ...)              \
    [&] {                                               \
        if (!(COND)) {                                  \
            constexpr static bool CONST_NAME = false;   \
            return __VA_ARGS__();                       \
        }                                               \
        constexpr static bool CONST_NAME = true;        \
        return __VA_ARGS__();                           \
    }()
// Maps a runtime split count onto the smallest supported compile-time bucket.
// Expands to an immediately-invoked lambda that declares
// `constexpr static int NAME` as one of 32, 64, 72, 96, 128, 144 or 160 (the
// first bucket that is >= NUM_SPLITS) and then invokes the callable passed as
// the trailing argument, returning its result. Values above 160 trigger
// FLASH_ASSERT(false).
// NUM_SPLITS is parenthesised in every comparison so that arguments containing
// low-precedence operators (e.g. `x & mask`, `a ? b : c`) are compared as a
// whole. Note that it may be evaluated several times.
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...)    \
    [&] {                                               \
        if ((NUM_SPLITS) <= 32) {                       \
            constexpr static int NAME = 32;             \
            return __VA_ARGS__();                       \
        } else if ((NUM_SPLITS) <= 64) {                \
            constexpr static int NAME = 64;             \
            return __VA_ARGS__();                       \
        } else if ((NUM_SPLITS) <= 72) {                \
            constexpr static int NAME = 72;             \
            return __VA_ARGS__();                       \
        } else if ((NUM_SPLITS) <= 96) {                \
            constexpr static int NAME = 96;             \
            return __VA_ARGS__();                       \
        } else if ((NUM_SPLITS) <= 128) {               \
            constexpr static int NAME = 128;            \
            return __VA_ARGS__();                       \
        } else if ((NUM_SPLITS) <= 144) {               \
            constexpr static int NAME = 144;            \
            return __VA_ARGS__();                       \
        } else if ((NUM_SPLITS) <= 160) {               \
            constexpr static int NAME = 160;            \
            return __VA_ARGS__();                       \
        } else {                                        \
            FLASH_ASSERT(false);                        \
        }                                               \
    }()
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
// #include <cuda_bf16.h>
// #include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "flash_mla.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary functor returning the larger of its two arguments (by `>`);
// when neither compares greater, the second argument is returned.
template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) {
        if (x > y) {
            return x;
        }
        return y;
    }
};
// float specialisation of MaxOp that forwards to max(x, y) instead of the
// generic comparison.
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary functor returning the sum of its two arguments.
template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) {
        return x + y;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Recursive butterfly all-reduce across lanes using __shfl_xor with width 64.
// Each level combines a lane's value with that of the lane whose index differs
// in bit THREADS/2, then recurses into Allreduce<THREADS/2>. Explicit
// specialisations of Allreduce terminate (or shorten) the recursion.
template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 64 || THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2);

    // x:  this lane's contribution.
    // op: binary combiner (e.g. MaxOp / SumOp).
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int kPartnerMask = THREADS / 2;
        const T combined = op(x, __shfl_xor(x, kPartnerMask, 64));
        return Allreduce<kPartnerMask>::run(combined, op);
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Recursion terminator: a group of one lane has nothing to combine, so the
// value is returned unchanged and `op` is not invoked.
template<>
struct Allreduce<1> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator & /*op*/) {
        return x;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Specialisation for 32: performs a single exchange with the lane whose index
// differs in bit 16 and then stops -- it does NOT continue with offsets 8, 4,
// 2, 1. As a consequence Allreduce<64>, which recurses into this
// specialisation, only combines across offsets 32 and 16.
// NOTE(review): presumably the remaining lanes are reduced elsewhere -- confirm.
template<>
struct Allreduce<32> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        return op(x, __shfl_xor(x, 16, 64));
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converts every element of `tensor` to To_type and returns the result.
//
// * If To_type equals the tensor's element type, a copy of `tensor` is
//   returned as-is. The conversion path lives in the `else` branch of the
//   `if constexpr`, so it is discarded in that case and cannot contribute a
//   second, conflicting `return` type to the deduced `auto`.
// * Otherwise a new owning tensor with the same layout is created and filled
//   through cutlass::NumericArrayConverter over all `numel` elements at once.
//   HACK: this reinterprets tensor.data() as a flat cutlass::Array, which
//   requires the tensor's storage to be contiguous.
//
// Rounding mode:
// * __gfx938__: round_to_nearest for bfloat16_t and float_e4m3_t targets, the
//   converter's default for everything else.
// * other targets: bfloat16_t uses round_toward_zero when
//   FLASH_MLA_BF16_TYPE == 0 (the default defined here if unset) and
//   round_half_ulp_truncate otherwise; everything else uses the default.
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    if constexpr (std::is_same_v<To_type, From_type>) {
        return tensor;
    } else {
        constexpr int numel = decltype(size(tensor))::value;
        Tensor tensor_To_type = make_tensor<To_type>(layout(tensor));
        cutlass::Array<To_type, numel> *result_ptr = reinterpret_cast<cutlass::Array<To_type, numel> *>(tensor_To_type.data());
        const cutlass::Array<From_type, numel> &source = *reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data());
#if defined(__gfx938__)
        if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t> || std::is_same_v<To_type, cutlass::float_e4m3_t>) {
            cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_to_nearest> convert_op;
            *result_ptr = convert_op(source);
        } else {
            cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
            *result_ptr = convert_op(source);
        }
#else
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
        if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
#if FLASH_MLA_BF16_TYPE == 0
            cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_toward_zero> convert_op;
#else
            cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_half_ulp_truncate> convert_op;
#endif
            *result_ptr = convert_op(source);
        } else {
            cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
            *result_ptr = convert_op(source);
        }
#endif
        return tensor_To_type;
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Reads one tile from `src` with the __builtin_amdgcn_ds_read_m32x16f16
// builtin and stores the 8 resulting 16-bit values at dst(0, row, col).
// The byte offset handed to the builtin is src.layout()(0, row, col) * 2
// and must be a compile-time constant.
// The pointer is cast to address_space(3) for the builtin.
template<int row, int col, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col(Tensor0& src, Tensor1& dst)
{
    auto layout = src.layout();
    constexpr short offset = layout(0, row, col) * 2;
    auto lds = reinterpret_cast<__fp16 *>(src.data().get());
    auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset);

    // Copy the result as raw 16-bit words so no value conversion takes place.
    uint16_t * loaded = reinterpret_cast<uint16_t*>(&d);
    uint16_t * out = reinterpret_cast<uint16_t*>(&(dst(0, row, col)));
    #pragma unroll
    for (int i = 0; i < 8; ++i) {
        out[i] = loaded[i];
    }
}
// Same as __ds_read_m32x16_row_col, except that the source tile is addressed
// with (0, row, col) while the 8 resulting 16-bit values are stored at
// dst(0, r_row, col), i.e. the destination row can differ from the source row.
template<int row, int col, int r_row, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col_rrow(Tensor0& src, Tensor1& dst)
{
    auto layout = src.layout();
    constexpr short offset = layout(0, row, col) * 2;
    auto lds = reinterpret_cast<__fp16 *>(src.data().get());
    auto d = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset);

    // Copy the result as raw 16-bit words so no value conversion takes place.
    uint16_t * loaded = reinterpret_cast<uint16_t*>(&d);
    uint16_t * out = reinterpret_cast<uint16_t*>(&(dst(0, r_row, col)));
    #pragma unroll
    for (int i = 0; i < 8; ++i) {
        out[i] = loaded[i];
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Re-expresses a rank-3 accumulator layout as a rank-2 (row, col) layout.
// Mode 0 is split with logical_divide by Shape<_1>; the row mode is built from
// the original mode 1 and the column mode from the remainder of mode 0
// together with the original mode 2.
// Example: (_4,_1,_2):(_1,_0,_4) -> (1, (4, 2)):((_0),(_1,_4))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(rank(acc_layout))::value == 3);
    // (_4,_1,_2):(_1,_0,_4) -> ((_1,_4),_1,_2):((_0,_1),_0,_4)
    auto divided = logical_divide(acc_layout, Shape<_1>{});
    auto row_mode = make_layout(get<1>(divided));
    auto col_mode = make_layout(get<1>(get<0>(divided)), get<2>(divided));
    return make_layout(row_mode, col_mode);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Re-lays-out accumulator values so they can be used as the A operand of
// `tiled_mma`, by round-tripping them through the staging tensor `sAcc`
// (presumably shared memory, given the __syncthreads() between write and
// read -- confirm with the caller):
//   1. every thread writes its fragment tOrP to sAcc using a tiled copy built
//      for the MMA's C layout,
//   2. all threads synchronise,
//   3. every thread reads sAcc back into a fragment partitioned for the MMA's
//      A layout, which is returned.
// No barrier is issued after the read-back; the caller must make sure sAcc is
// not overwritten while other threads may still be reading it.
template <class TiledMma,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, Tensor<Engine0, Layout0> const& tOrP,
Tensor<Engine1, Layout1> const& sAcc)
{
using Value_type = typename Engine0::value_type;
int tidx = threadIdx.x;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
// Step 1: store the C-layout fragment into sAcc.
auto smem_tiled_copy_ACC = make_tiled_copy_C(Copy_Atom<DefaultCopy, Value_type>{}, tiled_mma);
auto smem_thr_copy_ACC = smem_tiled_copy_ACC.get_thread_slice(tidx);
Tensor taccOr = smem_thr_copy_ACC.retile_S(tOrP);
Tensor taccOs = smem_thr_copy_ACC.partition_D(sAcc);
// Debug aid (disabled): prints the retiled source fragment, e.g.
// raw_ptr_16b(0x2000000000010) o ((_1,_4),_1,_4):((_0,_1),_0,_4)
// if (cute::thread0())
// { taccOr
// print("taccOr\n"); print(taccOr); print("\n");
// }
cute::copy(smem_tiled_copy_ACC, taccOr, taccOs);
// asm volatile("s_waitcnt lgkmcnt(0)\n\t");
// Step 2: make every thread's writes visible before anyone reads sAcc.
__syncthreads();
// Step 3: load sAcc back using the A-operand partitioning.
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<DefaultCopy, Value_type>{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
Tensor tSsACC = smem_thr_copy_A.partition_S(sAcc);
Tensor tSrACC = thr_mma.partition_fragment_A(sAcc);
Tensor tSrACC_copy_view = smem_thr_copy_A.retile_D(tSrACC);
// NOTE(review): this copy is issued with smem_tiled_copy_ACC although the
// operands were partitioned with smem_tiled_copy_A. Both are built from the
// same Copy_Atom<DefaultCopy, Value_type>, so this is presumably equivalent --
// confirm, or switch to smem_tiled_copy_A for clarity.
cute::copy(smem_tiled_copy_ACC, tSsACC, tSrACC_copy_view);
// asm volatile("s_waitcnt lgkmcnt(0)\n\t");
// __syncthreads(); // This sync was removed on 2024-06-13.
return tSrACC;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Predicated tile copy S -> D over a rank-3 (MMA, MMA_M, MMA_K) partition.
//
// Rows (mode 1): a row m is copied when Is_even_MN is set or when
//   get<0>(identity_MN(0, m, 0)) < max_MN; otherwise it is zeroed if
//   Clear_OOB_MN and left untouched if not.
// Columns (mode 2): within a copied row, column k is copied when Is_even_K is
//   set or predicate_K(k) is true; otherwise it is zeroed if Clear_OOB_K and
//   left untouched if not.
//
// Clear_OOB_MN without Clear_OOB_K is rejected at compile time.
// begin_k is accepted but not used by this function.
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0, int begin_k=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        const bool row_in_bounds = Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN;
        if (!row_in_bounds) {
            if (Clear_OOB_MN) {
                cute::clear(D(_, m, _));
            }
            continue;
        }
        #pragma unroll
        for (int k = 0; k < size<2>(S); ++k) {
            const bool col_in_bounds = Is_even_K || predicate_K(k);
            if (col_in_bounds) {
                cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
            } else if (Clear_OOB_K) {
                cute::clear(D(_, m, k));
            }
        }
    }
}
// Predicated copy of a single K slice: S(_, m, k_idx) -> D(_, m, k_idx_smem)
// for every row m of a rank-3 (MMA, MMA_M, MMA_K) partition.
//
// Rows (mode 1): a row m is handled when Is_even_MN is set or when
//   get<0>(identity_MN(0, m, 0)) < max_MN; otherwise the whole row of D is
//   zeroed if Clear_OOB_MN and left untouched if not.
// Column: the slice is copied when Is_even_K is set or predicate_K(k_idx) is
//   true; otherwise the destination slice is zeroed if Clear_OOB_K.
//
// S and D may have different K extents (the MMA_K size check is intentionally
// disabled), which is why source and destination use separate indices. The
// out-of-bounds clear therefore targets k_idx_smem -- the slot the copy would
// have written -- and not k_idx, which indexes S and may not be a valid (or
// the intended) slice of D.
// Clear_OOB_MN without Clear_OOB_K is rejected at compile time.
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_k_idx(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                           Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                           Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0, int k_idx=0, int k_idx_smem=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
    // CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            if (Is_even_K || predicate_K(k_idx)) {
                cute::copy(tiled_copy, S(_, m, k_idx), D(_, m, k_idx_smem));
            } else if (Clear_OOB_K) {
                // Zero the same destination slice the copy above would have filled.
                cute::clear(D(_, m, k_idx_smem));
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}
// Emits `s_waitcnt vmcnt(N)` immediately followed by `s_barrier`.
// N is passed as an immediate ("n" constraint), so it must be a compile-time
// constant. Note that, despite the name, this also executes a barrier
// instruction, so it has to be reached by all participating waves
// (per the usual s_barrier rules -- confirm against the ISA documentation).
template <int N>
CUTE_HOST_DEVICE
void wait_vmcnt() {
asm volatile("s_waitcnt vmcnt(%0) ;\n\t"
"s_barrier; \n\t"
:: "n"(N));
}
// Stores the 128-bit value `src` at the address of element dst(0, 0, k_idx).
// Despite the name, this is a plain C++ store; no inline assembly is involved.
// The destination is accessed through a uint128_t pointer, so it is presumably
// expected to be 16-byte aligned -- confirm with the caller.
template<
    class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
asm_ds_write(const uint128_t & src, Tensor<SrcEngine, SrcLayout> & dst, int k_idx)
{
    *reinterpret_cast<uint128_t*>(&dst(0, 0, k_idx)) = src;
}
// Stores the 128-bit value `src` (e.g. the result of a buffer load) at the
// address of element dst(0, 0, k_idx).
// The destination is accessed through a uint128_t pointer, so it is presumably
// expected to be 16-byte aligned -- confirm with the caller.
template<
    class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
buffer_to_tensor(const uint128_t & src, Tensor<SrcEngine, SrcLayout> & dst, int k_idx)
{
    *reinterpret_cast<uint128_t*>(&dst(0, 0, k_idx)) = src;
}
// Issues one buffer_load_dwordx4 (16 bytes) per thread from the global tensor
// `src` into `dst`. Elements are 2 bytes wide (element_size) and each thread
// fetches 8 of them (elements_per_thread).
//
// Thread -> element mapping:
//   mma_layout == true : row = (tidx % 16) + warp_id * 16, column group = lane / 16
//   mma_layout == false: row = tidx / 4,                   column group = lane % 4
// In both cases the column is (column group * 8 + k_idx * 32) elements and the
// byte offset passed to the load is (row * row_stride + column) * element_size.
//
// When !Is_even_MN, threads whose row is >= max_MN use a byte offset of -1
// instead (presumably so that the buffer bounds check rejects the access --
// TODO confirm).
//
// use_asm selects the hand-written inline assembly over the compiler builtin.
// NOTE(review): the inline-asm path contains no wait; callers presumably wait
// for the load themselves (e.g. via s_waitcnt) before reading `dst` -- confirm.
// Is_even_K, offset_k and the local mma_k are not used by this function.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy(
Tensor<SrcEngine, SrcLayout> const& src,
uint128_t & dst,
int k_idx_, const int row_stride,
int offset_k,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane: take the value from the first active lane for the whole wave.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 8;
if constexpr (mma_layout)
{
// Split the 64-bit base pointer into two 32-bit halves.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// 4-dword descriptor handed to the buffer load: words 0/1 hold the base address.
// NOTE(review): words 2/3 look like the num_records and flags fields of a
// buffer resource descriptor -- confirm against the ISA documentation.
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Rows past max_MN get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint128_t*>(&res);
}
}
else
{
// Same descriptor as above, but the base address is written directly
// without readfirstlane.
uint32x4_t global_addr = {0};
*(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
// Four consecutive threads cover one row.
int row = tidx / 4;
int col = lane % 4;
int row_offset = row;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Rows past max_MN get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint128_t*>(&res);
}
}
}
// 1-byte-element variant of buffer_load_copy: issues one buffer_load_dwordx4
// (16 bytes = 16 elements) per thread from `src` into the 4-dword vector `dst`.
//
// Only the mma_layout == true case is implemented; with mma_layout == false
// the function does nothing and `dst` is left untouched.
//
// Thread -> element mapping: row = (tidx % 16) + warp_id * 16,
// column = (lane / 16) * 16 + k_idx * 64, byte offset =
// (row * row_stride + column) * element_size.
// When !Is_even_MN, threads whose row is >= max_MN use a byte offset of -1
// instead (presumably so that the buffer bounds check rejects the access --
// TODO confirm).
//
// use_asm selects the hand-written inline assembly over the compiler builtin.
// Is_even_K, offset_k and the local mma_k are not used by this function.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_fp8(
Tensor<SrcEngine, SrcLayout> const& src,
uint32x4_t & dst,
int k_idx_, const int row_stride,
int offset_k,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane: take the value from the first active lane for the whole wave.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 16;
if constexpr (mma_layout)
{
// Split the 64-bit base pointer into two 32-bit halves.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// 4-dword descriptor handed to the buffer load: words 0/1 hold the base address.
// NOTE(review): words 2/3 look like the num_records and flags fields of a
// buffer resource descriptor -- confirm against the ISA documentation.
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Rows past max_MN get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint32x4_t*>(&res);
}
}
}
// Two-lane vector of uint32_t (8 bytes); used as the destination of the
// dwordx2 buffer load in buffer_load_copy_fp8x2.
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// 1-byte-element loader that issues one buffer_load_dwordx2 (8 bytes =
// 8 elements) per thread from `src` into the 2-dword vector `dst`.
//
// Only the mma_layout == true case is implemented; with mma_layout == false
// the function does nothing and `dst` is left untouched.
//
// Thread -> element mapping:
//   row    = (lane / 8) * 4 + (warp_id % 4) + offset_k
//   column = (lane % 8) * 8 + (warp_id / 4) * 64 + k_idx * 128
//   byte offset = (row * row_stride + column) * element_size
// The byte offset is replaced by -1 (presumably so that the buffer bounds
// check rejects the access -- TODO confirm) when
//   * !Is_even_MN and row >= max_MN, or
//   * column >= 576 (hard-coded row width limit; checked regardless of Is_even_K).
//
// use_asm selects the hand-written inline assembly over the compiler builtin.
// Is_even_K and the local mma_k are not used by this function.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_fp8x2(
Tensor<SrcEngine, SrcLayout> const& src,
uint32x2_t & dst,
int k_idx_, const int row_stride,
int offset_k,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane: take the value from the first active lane for the whole wave.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 8;
if constexpr (mma_layout)
{
// Split the 64-bit base pointer into two 32-bit halves.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// 4-dword descriptor handed to the buffer load: words 0/1 hold the base address.
// NOTE(review): words 2/3 look like the num_records and flags fields of a
// buffer resource descriptor -- confirm against the ISA documentation.
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
// Eight consecutive lanes cover 64 bytes of one row.
int row = lane / 8;
int col = lane % 8;
int row_offset = row * 4 + ((warp_id % 4)) + offset_k;
int col_offset = col * elements_per_thread + (warp_id / 4 ) * 64 + k_idx * 128;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Out-of-range rows / columns get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (col_offset >= 576) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint32x2_t*>(&res);
}
}
}
// 1-byte-element loader that issues one buffer_load_dwordx4 (16 bytes =
// 16 elements) per thread from `src` into the 128-bit value `dst`.
// Apart from the destination type (uint128_t instead of uint32x4_t) the body
// matches buffer_load_copy_fp8.
//
// Only the mma_layout == true case is implemented; with mma_layout == false
// the function does nothing and `dst` is left untouched.
//
// Thread -> element mapping: row = (tidx % 16) + warp_id * 16,
// column = (lane / 16) * 16 + k_idx * 64, byte offset =
// (row * row_stride + column) * element_size.
// When !Is_even_MN, threads whose row is >= max_MN use a byte offset of -1
// instead (presumably so that the buffer bounds check rejects the access --
// TODO confirm).
//
// use_asm selects the hand-written inline assembly over the compiler builtin.
// Is_even_K, offset_k and the local mma_k are not used by this function.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_qkvfp8(
Tensor<SrcEngine, SrcLayout> const& src,
uint128_t & dst,
int k_idx_, const int row_stride,
int offset_k,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane: take the value from the first active lane for the whole wave.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 16;
if constexpr (mma_layout)
{
// Split the 64-bit base pointer into two 32-bit halves.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// 4-dword descriptor handed to the buffer load: words 0/1 hold the base address.
// NOTE(review): words 2/3 look like the num_records and flags fields of a
// buffer resource descriptor -- confirm against the ISA documentation.
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Rows past max_MN get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint128_t*>(&res);
}
}
}
// Issues a buffer_load_dwordx4 with the `lds` modifier for 1-byte elements:
// each thread requests 16 bytes from `src`, with m0 set to an address derived
// from `dst` (presumably the data is written straight to LDS at that address
// rather than to a register -- confirm against the ISA documentation).
//
// Only the Is_load_Q == true case is implemented; otherwise the function does
// nothing.
//
// Thread -> element mapping: row = lane % 16,
// column = ((lane / 16) + warp_id * 4) * 16, byte offset =
// (row * row_stride + column) * element_size. Unlike lds_direct_copy_qkvfp8,
// k_idx does NOT enter the source column here; it only selects the
// destination slot (k_idx * mma_k bytes).
// The byte offset is replaced by -1 (presumably so that the buffer bounds
// check rejects the access -- TODO confirm) when
//   * !Is_even_MN and row >= max_MN, or
//   * !Is_even_K and column >= 576.
//
// NOTE(review): the asm writes m0 but declares neither an "m0" nor a "memory"
// clobber, and contains no wait; callers presumably synchronise themselves --
// confirm.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8_pe(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
if constexpr (Is_load_Q) {
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane: take the value from the first active lane for the whole wave.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// Split the 64-bit base pointer into two 32-bit halves.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// 4-dword descriptor handed to the buffer load: words 0/1 hold the base address.
// NOTE(review): words 2/3 look like the num_records and flags fields of a
// buffer resource descriptor -- confirm against the ISA documentation.
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// Bytes per k_idx slot in the destination.
int mma_k = 16*256;
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
int col_offset = (col + warp_id * 4) * elements_per_thread;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Out-of-range rows / columns get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
// Destination base for this wave: dst base + per-warp slice + per-k_idx slot.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// Load ldsAddrPerWave into m0, then issue the load with the `lds` modifier.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
}
// Issues a buffer_load_dwordx4 with the `lds` modifier for 1-byte elements:
// each thread requests 16 bytes from `src`, with m0 set to an address derived
// from `dst` (presumably the data is written straight to LDS at that address
// rather than to a register -- confirm against the ISA documentation).
//
// Is_load_Q == true:
//   row = lane % 16, column = ((lane / 16) + warp_id * 4) * 16 + k_idx * 256,
//   byte offset = (row * row_stride + column) * element_size,
//   destination slot stride mma_k = 16*256 bytes per k_idx.
//   The offset becomes -1 when (!Is_even_MN and row >= max_MN) or
//   (!Is_even_K and column >= 576).
//   The asm in this branch is emitted for every target.
//
// Is_load_Q == false:
//   Uses a swizzled mapping: the 16-byte column group is
//   ((lane / 8) ^ (lane % 8)) % 4, the row is lane / 4 with its lowest bit
//   flipped for rows >= 8 (8<->9, 10<->11, ...), plus warp_id * 16.
//   column = group * 16 + k_idx * 64, destination slot stride
//   mma_k = 64*64 bytes per k_idx.
//   The offset becomes -1 when !Is_even_MN and row >= max_MN; Is_even_K is
//   not consulted.
//   The asm in this branch is only emitted when __gfx938__ is defined; on
//   other targets the branch computes addresses but loads nothing.
//
// An offset of -1 presumably makes the buffer bounds check reject the access
// -- TODO confirm.
// NOTE(review): the asm writes m0 but declares neither an "m0" nor a "memory"
// clobber, and contains no wait; callers presumably synchronise themselves --
// confirm.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
if constexpr (Is_load_Q) {
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane: take the value from the first active lane for the whole wave.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// Split the 64-bit base pointer into two 32-bit halves.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// 4-dword descriptor handed to the buffer load: words 0/1 hold the base address.
// NOTE(review): words 2/3 look like the num_records and flags fields of a
// buffer resource descriptor -- confirm against the ISA documentation.
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// Bytes per k_idx slot in the destination.
int mma_k = 16*256;
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 256;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Out-of-range rows / columns get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
// Destination base for this wave: dst base + per-warp slice + per-k_idx slot.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// Load ldsAddrPerWave into m0, then issue the load with the `lds` modifier.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
} else {
constexpr int warp_size = 64;
int tidx = threadIdx.x;//0-256
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;//0-63
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576
const int offset_s = 0;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
// Split the 64-bit base pointer into two 32-bit halves.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// Same 4-dword descriptor as in the Is_load_Q branch, but the base address
// halves go through readfirstlane here.
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;//64*16*1
// Bytes per k_idx slot in the destination.
int mma_k = 64*64;
// int row = lane / 4;
// int col = lane % 4;
// int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
// To be optimized: for the last 8 rows, the row numbers need to be swapped.
int virtual_row = lane / 8;//0
int virtual_col = lane % 8;//0
// XOR swizzle of the 8x8 (virtual_row, virtual_col) grid.
int swizzle_col = virtual_row ^ virtual_col;
int row = lane / 4;//0
// 8->9 9->8
// For rows >= 8 the comparison yields 1, so the XOR flips the lowest bit of
// the row index; rows < 8 are unchanged.
row = (row >= 8 ) ^ row;
// row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
int col = swizzle_col % 4;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes
// Rows past max_MN get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
//int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size;
// Destination base for this wave: dst base + per-warp slice + per-k_idx slot.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx) * mma_k * element_size;
// The load itself is only compiled for __gfx938__.
#if defined(__gfx938__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
// Issues a buffer_load_dwordx4 with the `lds` modifier for 1-byte elements:
// each thread requests 16 bytes from `src`, with m0 set to an address derived
// from `dst` (presumably the data is written straight to LDS at that address
// rather than to a register -- confirm against the ISA documentation).
//
// Is_load_Q == true : intentionally empty, nothing is loaded.
// Is_load_Q == false:
//   Uses a swizzled mapping: the 16-byte column group is
//   ((lane / 8) ^ (lane % 8)) % 4, the row is lane / 4 with its lowest bit
//   flipped for rows >= 8 (8<->9, 10<->11, ...), plus warp_id * 16.
//   column = group * 16 + k_idx * 64, destination slot stride
//   mma_k = 64*64 bytes per k_idx.
//   When !Is_even_MN and row >= max_MN the byte offset becomes -1 (presumably
//   so that the buffer bounds check rejects the access -- TODO confirm).
//   Is_even_K is not consulted.
//   The asm is only emitted when __gfx936__ or __gfx938__ is defined; on other
//   targets the function computes addresses but loads nothing.
//
// NOTE(review): the asm writes m0 but declares neither an "m0" nor a "memory"
// clobber, and contains no wait; callers presumably synchronise themselves --
// confirm.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_fp8(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
if constexpr (Is_load_Q) {
} else {
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane: take the value from the first active lane for the whole wave.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
// Split the 64-bit base pointer into two 32-bit halves.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// 4-dword descriptor handed to the buffer load: words 0/1 hold the base address.
// NOTE(review): words 2/3 look like the num_records and flags fields of a
// buffer resource descriptor -- confirm against the ISA documentation.
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// Bytes per k_idx slot in the destination.
int mma_k = 64*64;
// int row = lane / 4;
// int col = lane % 4;
// int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
// To be optimized: for the last 8 rows, the row numbers need to be swapped.
int virtual_row = lane / 8;
int virtual_col = lane % 8;
// XOR swizzle of the 8x8 (virtual_row, virtual_col) grid.
int swizzle_col = virtual_row ^ virtual_col;
int row = lane / 4;
// 8->9 9->8
// For rows >= 8 the comparison yields 1, so the XOR flips the lowest bit of
// the row index; rows < 8 are unchanged.
row = (row >= 8 ) ^ row;
// row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
int col = swizzle_col % 4;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes
// Rows past max_MN get an offset of -1 (see header comment).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// Destination base for this wave: dst base + per-warp slice + per-k_idx slot.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// if (thread(0))
// {
// printf("offset_v = %d %d \n", offset_v, warp_id * bytes_per_warp + k_idx * mma_k * element_size);
// }
// The load itself is only compiled for __gfx936__ / __gfx938__.
#if defined(__gfx936__) || defined(__gfx938__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
// Issues one direct global-memory -> LDS buffer load for 2-byte elements, using
// a thread mapping in which 4 waves interleave rows and wave groups of 4 select
// a 64-column half of each 128-column k tile.
//
// Template parameters:
//   Is_even_MN - when false, threads whose row_offset >= max_MN use offset_v = -1.
//   Is_even_K  - when false, threads whose col_offset >= 576 use offset_v = -1
//                (576 is hard-coded; presumably the head dimension - TODO confirm).
// Parameters:
//   src        - source tensor; only its base pointer is used.
//   dst        - destination tensor; its base pointer is used as the LDS address.
//   k_idx_     - k-tile index; made wave-uniform via readfirstlane.
//   row_stride - row pitch in elements (it is multiplied by element_size below).
//   offset_r   - extra row offset added to every thread's row.
//   max_MN     - row bound, only consulted when Is_even_MN is false.
//
// NOTE(review): unlike the sibling copies in this file, the asm statement here
// is not wrapped in a __gfx936__/__gfx938__ guard.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_tp1(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride, int offset_r,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane makes the wave id uniform so it can live in a scalar register.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// Build the 4-dword buffer descriptor by hand.
// An earlier variant (disabled) added 0x41000000 to word 1
// (bit 62: cache swizzle; bits 48~61: stride) and used 0xfffffffe for word 2.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
// Split the 64-bit source address into low (former) / high (latter) dwords.
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// Disabled: glob_ptr.latter |= 0x40000000;  // bit 62: cache swizzle; bits 48~61: stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
// Fixed constants; presumably record count and format/flags - TODO confirm.
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
// dwordx4 = 16 bytes per thread = 8 two-byte elements.
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
int mma_k = 32*64;
// 8 lanes span one row; consecutive lane groups are 4 rows apart and the
// wave id (mod 4) fills in the rows between them.
int row = (lane / 8) * 4 + warp_id % 4 + offset_r;
int col = (lane % 8);
int row_offset = row ;
// Waves 0-3 read the first 64 columns of the 128-wide k tile, waves 4+ the next 64.
int col_offset = col * elements_per_thread + k_idx * 128 + (warp_id / 4) * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
// LDS byte address: 32*64*2 bytes per wave group, 8*64*2 bytes per wave inside
// the group, 32*128*2 bytes per k tile.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + (warp_id / 4) * (32 * 64 * 2) +
(warp_id % 4) * 8 * 64 * 2 + k_idx * 32 * 128 * 2;
// Additionally sets bits 16 and up of the value written to m0 from the wave id.
// NOTE(review): the meaning of this m0 field is not visible here - confirm
// against the target ISA documentation.
ldsAddrPerWave |= (((warp_id % 4) * 2) << 16);
// %0 = per-thread byte offset (VGPR), %1 = value moved into m0,
// %2 = buffer descriptor, %3 = scalar offset (always 0).
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
// Reads one fragment from LDS with the ds_read_m32x16f16 builtin and stores its
// first eight 16-bit lanes into `dst`, starting at element (0, row, col).
//
//   row, col - compile-time fragment coordinates; `row` also selects the LDS
//              immediate offset (row * 32 * 64 * 2, stored in a `short`).
//   lds_ptr  - LDS base pointer for this thread's read.
//   dst      - destination tensor; the eight values are written through a raw
//              uint16_t pointer taken at dst(0, row, col), so they are assumed
//              to be contiguous in memory - TODO confirm for the caller's layout.
template<int row, int col, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col_lds(__fp16* lds_ptr, Tensor1& dst)
{
    constexpr short kLdsImmOffset = row * 32 * 64 * 2;
    auto fragment = __builtin_amdgcn_ds_read_m32x16f16(
        (__attribute__((address_space(3))) __fp16*)(lds_ptr), kLdsImmOffset);

    // Copy the payload as raw 16-bit words so no floating-point conversion happens.
    uint16_t* src_words = reinterpret_cast<uint16_t*>(&fragment);
    uint16_t* dst_words = reinterpret_cast<uint16_t*>(&(dst(0, row, col)));
    #pragma unroll
    for (int word = 0; word < 8; ++word) {
        dst_words[word] = src_words[word];
    }
}
// Re-lays out four accumulator values per thread into an A-operand register
// fragment for `tiled_mma_o`, bouncing the data through `sAcc`.
//
// Phase 1: every thread stores tOrP(0..3, 0, 0) into sAcc at a position derived
//          from its wave id and lane (consecutive values are 16 * 8 elements apart).
// Phase 2: after a block-wide barrier, every thread reads back 8 consecutive
//          elements and places them in tSrACC(0..3, 0, 0) and tSrACC(0..3, 0, 1).
//
//   tiled_mma   - unused; kept so the signature stays unchanged.
//   tiled_mma_o - MMA whose thread slice defines the returned fragment type.
//   tOrP        - source fragment; only elements (0..3, 0, 0) are read.
//   sAcc        - staging tensor, written and then read by index.
// Returns the populated A fragment.
template <class TiledMma, class TiledMma_O,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ auto convert_layout_acc_Aregs_tp1(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor<Engine0, Layout0> const& tOrP,
Tensor<Engine1, Layout1> const& sAcc)
{
    const int tidx = threadIdx.x;
    const int lane_id = tidx % 64;
    const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);

    // Each of the 4 waves in a group owns a 64 * 8 element slab.
    const int wave_slab = (warp_id % 4) * 64 * 8;

    // Scatter: wave group picks the 4-wide half, lane / 16 the element inside it,
    // lane % 16 the 8-element row.
    const int store_base = (warp_id / 4) * 4 + lane_id / 16 + wave_slab + (tidx % 16) * 8;
    #pragma unroll
    for (int v = 0; v < 4; ++v) {
        sAcc[store_base + v * 16 * 8] = tOrP(v, 0, 0);
    }

    // All stores must land before any thread reads the staging buffer back.
    __syncthreads();

    using SmemLayoutAtomP = Layout<Shape<Int<64>, Int<32>>, Stride<Int<64>, _1>>;
    using SmemLayoutP = decltype(tile_to_shape(
        SmemLayoutAtomP{},
        Shape<Int<64>, Int<32>>{}));
    Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{});
    auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
    Tensor tSrACC = thr_mma_o.partition_fragment_A(sP_tmp);

    // Gather: 8 consecutive elements per lane, split into two k-slices of 4.
    const int load_base = lane_id * 8 + wave_slab;
    #pragma unroll
    for (int k = 0; k < 2; ++k) {
        #pragma unroll
        for (int v = 0; v < 4; ++v) {
            tSrACC(v, 0, k) = sAcc[load_base + k * 4 + v];
        }
    }
    return tSrACC;
}
// Direct global-memory -> LDS buffer load for 2-byte elements, sparse-K variant.
//
// Only the Is_load_Q == true path does anything; with Is_load_Q == false (the
// default) or when neither __gfx936__ nor __gfx938__ is defined, the function
// body is empty.
//
// The Q path mirrors the Is_load_Q branch of lds_direct_copy except that the
// LDS destination does not advance with k_idx (the k term is multiplied by 0),
// so every k tile lands in the same LDS region.
//
// Template parameters:
//   Is_even_MN - when false, threads whose row_offset >= max_MN use offset_v = -1.
//   Is_even_K  - when false, threads whose col_offset >= 576 use offset_v = -1
//                (576 is hard-coded; presumably the head dimension - TODO confirm).
//   Is_load_Q  - selects the only implemented path.
// Parameters:
//   src        - source tensor; only its base pointer is used.
//   dst        - destination tensor; its base pointer is used as the LDS address.
//   k_idx_     - k-tile index; made wave-uniform via readfirstlane.
//   row_stride - row pitch in elements (multiplied by element_size below).
//   max_MN     - row bound, only consulted when Is_even_MN is false.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_sparse_k(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
#if defined(__gfx936__) || defined(__gfx938__)
if constexpr (Is_load_Q) {
// Original note: "32x64".
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane makes the wave id uniform so it can live in a scalar register.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
// Split the 64-bit source address into low (former) / high (latter) dwords.
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// Disabled: glob_ptr.latter |= 0x40000000;  // bit 62: cache swizzle; bits 48~61: stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
// Fixed constants; presumably record count and format/flags - TODO confirm.
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
// dwordx4 = 16 bytes per thread = 8 two-byte elements.
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
int mma_k = 16*128;
// 16 rows per wave; lane / 16 picks one of 4 column chunks inside the wave.
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
// Each wave covers 4 chunks of 8 elements; each k tile is 128 columns wide.
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 128;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
// The k contribution is deliberately zeroed (0 * mma_k): the LDS slot is reused.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + 0 * mma_k * element_size;
// %0 = per-thread byte offset (VGPR), %1 = LDS address moved into m0,
// %2 = buffer descriptor, %3 = scalar offset (always 0).
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
#endif
}
// Direct global-memory -> LDS buffer load for 2-byte elements.
//
// Two thread mappings are provided, selected at compile time:
//   Is_load_Q == true  : 16 rows x (4 chunks per wave), 128 columns per k tile;
//                        out-of-range columns (>= 576) are masked when Is_even_K
//                        is false.
//   Is_load_Q == false : 16 rows per wave with an XOR column swizzle, 32 columns
//                        per k tile; only rows are masked.
// The body is compiled only when __gfx936__ or __gfx938__ is defined; for
// __gfx928__ (and any other target) the function is empty.
//
// Template parameters:
//   Is_even_MN - when false, threads whose row_offset >= max_MN use offset_v = -1.
//   Is_even_K  - when false (Q path only), threads whose col_offset >= 576 use
//                offset_v = -1 (576 is hard-coded; presumably the head dimension -
//                TODO confirm).
//   Is_load_Q  - selects the thread mapping as described above.
// Parameters:
//   src        - source tensor; only its base pointer is used.
//   dst        - destination tensor; its base pointer is used as the LDS address.
//   k_idx_     - k-tile index; made wave-uniform via readfirstlane.
//   row_stride - row pitch in elements (multiplied by element_size below).
//   max_MN     - row bound, only consulted when Is_even_MN is false.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
#if defined(__gfx936__) || defined(__gfx938__)
{
if constexpr (Is_load_Q) {
// Original note: "32x64".
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane makes the wave id uniform so it can live in a scalar register.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
// Split the 64-bit source address into low (former) / high (latter) dwords.
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// Disabled: glob_ptr.latter |= 0x40000000;  // bit 62: cache swizzle; bits 48~61: stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
// Fixed constants; presumably record count and format/flags - TODO confirm.
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
// dwordx4 = 16 bytes per thread = 8 two-byte elements.
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
// LDS stride between consecutive k tiles, in elements.
int mma_k = 16*128;
// 16 rows per wave; lane / 16 picks one of 4 column chunks inside the wave.
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 128;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// %0 = per-thread byte offset (VGPR), %1 = LDS address moved into m0,
// %2 = buffer descriptor, %3 = scalar offset (always 0).
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
else {
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// Build the 4-dword buffer descriptor by hand.
// An earlier variant (disabled) added 0x41000000 to word 1
// (bit 62: cache swizzle; bits 48~61: stride) and used 0xfffffffe for word 2.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// Disabled: glob_ptr.latter |= 0x40000000;  // bit 62: cache swizzle; bits 48~61: stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
// Fixed constants; presumably record count and format/flags - TODO confirm.
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
// LDS stride between consecutive k tiles, in elements.
int mma_k = 32*64;
// TODO (translated from the original note): this needs optimizing; for the
// last 8 rows the row numbers have to be swapped.
int virtual_row = lane / 8;
int virtual_col = lane % 8;
int swizzle_col = virtual_row ^ virtual_col;
int row = lane / 4;
// For rows >= 8 the comparison yields 1, so the XOR flips the lowest bit
// (8 -> 9, 9 -> 8, ...); rows < 8 are left unchanged.
row = (row >= 8 ) ^ row;
int col = swizzle_col % 4;
// Each wave covers 16 rows.
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// Same operand roles as in the Q branch above.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
}
#elif defined(__gfx928__)
{
// No implementation for gfx928.
}
#endif
}
// Direct global-memory -> LDS buffer load for 2-byte elements where the row is
// supplied by the caller (e.g. gathered from an index list) instead of being
// derived from the thread id.
//
// The load uses both `idxen` and `offen`: the row goes in as the buffer index
// and the column byte offset as the buffer offset. The descriptor's stride
// field is set to row_stride * 2 and its third word to max_MN.
//
// Template parameters: Is_even_K, Is_even_MN and Use_cache_swizzle are not
// referenced by the current implementation.
// Parameters:
//   src        - source tensor; only its base pointer is used.
//   dst        - destination tensor; its base pointer is used as the LDS address.
//   row_offset - row to fetch; the value -1 is replaced by max_MN (presumably so
//                the index is out of range and no valid data is fetched -
//                TODO confirm against the ISA manual).
//   col        - column chunk for this thread (in units of 8 elements).
//   k_idx_     - k-tile index; made wave-uniform via readfirstlane.
//   row_stride - row pitch in elements.
//   max_MN     - written into descriptor word 2 and used as the replacement index.
//
// NOTE(review): the stride is OR-ed into bits 16+ of the high address dword, so
// this relies on those address bits being zero and on row_stride * 2 fitting in
// the stride field (bits 48~61 per the original comment) - confirm.
// NOTE(review): the asm statement is not wrapped in a __gfx936__/__gfx938__ guard.
template <bool Is_even_K=true,
bool Is_even_MN=true,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout
// class IdxEngine, class IdxLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy_for_prefill_sparse_mla(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int row_offset,
int col,
int k_idx_, const int row_stride, int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane makes the wave id uniform so it can live in a scalar register.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// Build the 4-dword buffer descriptor by hand.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
// Split the 64-bit source address into low (former) / high (latter) dwords.
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// Row pitch in bytes goes into the stride field.
glob_ptr.latter |= ((row_stride * 2) << 16); // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = max_MN;
// Fixed constant; presumably the format/flags word - TODO confirm.
global_addr[3] = 0x00020000;
// dwordx4 = 16 bytes per thread = 8 two-byte elements.
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
// LDS stride between consecutive k tiles, in elements.
int mma_k = 32*64;
// The swizzled row/col derivation used by lds_direct_copy is disabled here;
// row_offset and col come from the caller instead.
int col_offset = col * elements_per_thread + k_idx * 32;
// Only the column contributes to the byte offset; the row is passed as the index.
int offset_v = (col_offset) * element_size; // bytes
// LDS destination cycles through 4 k-tile slots (k_idx % 4).
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 4) * mma_k * element_size;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// Component 0 = buffer index (row), component 1 = byte offset.
uint32x2_t index_offset = {0};
index_offset[0] = row_offset == -1 ? max_MN : row_offset;
index_offset[1] = offset_v;
// %0 = {index, offset} VGPR pair, %1 = LDS address moved into m0,
// %2 = buffer descriptor, %3 = scalar offset (always 0).
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
// Loads 8 bytes (8 one-byte elements) from global memory into `dst` with a
// buffer load, addressing the row through the buffer index and the column
// through the byte offset.
//
// Nothing happens unless mma_layout == true: with the default template
// arguments `dst` is left untouched.
//
// Template parameters:
//   Is_even_MN, Is_even_K - not referenced by the current implementation.
//   mma_layout - enables the only implemented path.
//   use_asm    - true: issue the load through inline asm;
//                false: use __builtin_amdgcn_buffer_load_dwordx2.
// Parameters:
//   src          - source tensor; only its base pointer is used.
//   dst          - receives the 8 loaded bytes.
//   block_idx    - block number; together with batch_stride / row_stride it is
//                  added to the row index (builtin path only).
//   batch_stride - distance between blocks, in the same unit as row_stride.
//   row_offset   - row to fetch; it is reduced with (row_offset + 64) % 64, and a
//                  negative value forces offset_v to 0xFFFFFFFF (presumably an
//                  out-of-range access that yields no valid data - TODO confirm).
//   col          - column chunk for this thread (in units of 8 elements).
//   k_idx_       - k-tile index; made wave-uniform via readfirstlane.
//   row_stride   - row pitch; written into the descriptor's stride field.
//   max_MN       - unused.
//
// NOTE(review): the use_asm path differs from the builtin path - it does not add
// the (batch_stride / row_stride) * block_idx term, and it passes a
// two-component {index, offset} operand with only `offen` (no `idxen`), unlike
// lds_direct_copy_for_prefill_sparse_mla. Verify before enabling use_asm.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_sparse_fp8(
Tensor<SrcEngine, SrcLayout> const& src,
uint64_t & dst,
int block_idx, int batch_stride,
int row_offset, int col,
int k_idx_, const int row_stride,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 8;
if constexpr (mma_layout)
{
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
// Split the 64-bit source address into low (former) / high (latter) dwords.
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// Row pitch goes into the stride field (added, not OR-ed, in this function).
glob_ptr.latter += (row_stride << 16); // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
// Fixed constants; presumably record count and format/flags - TODO confirm.
global_addr[2] = 0xFFFFFFFE;
global_addr[3] = 0x00020000;
uint32_t col_offset = col * elements_per_thread + k_idx * 32;
// Only the column contributes to the byte offset; the row is passed as the index.
uint32_t offset_v = col_offset * element_size; // bytes
// -1 wraps to 0xFFFFFFFF because offset_v is unsigned.
if (row_offset < 0) offset_v = -1;
if constexpr(use_asm) {
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// Component 0 = row index, component 1 = byte offset.
uint32x2_t index_offset = {0};
index_offset[0] = (row_offset + 64 ) % 64;
index_offset[1] = offset_v;
// %0 = dst (output), %1 = index_offset, %2 = buffer descriptor.
asm volatile(
"buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(index_offset), "+s"(global_addr)
);
}
else {
// Index = row within the 64-row block + rows skipped for preceding blocks.
auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, (row_offset + 64 ) % 64 + (batch_stride / row_stride) * block_idx , offset_v, false, false);
dst = *reinterpret_cast<uint64_t*>(&res);
}
}
}
// Loads 16 bytes (8 two-byte elements) from global memory into `dst` with a
// buffer load. The descriptor base is the source pointer advanced by 512 + 16
// bytes; the block is selected through the buffer index and the row/column
// through the byte offset.
//
// Template parameters:
//   Is_even_MN, Is_even_K, mma_layout - not referenced by the current implementation.
//   use_asm    - true: issue the load through inline asm;
//                false: use __builtin_amdgcn_buffer_load_dwordx4.
// Parameters:
//   src          - source tensor; only its base pointer is used.
//   dst          - receives the 16 loaded bytes.
//   block_idx    - block number; (batch_stride / row_stride) * block_idx is used
//                  as the buffer index (builtin path only).
//   batch_stride - distance between blocks, in the same unit as row_stride.
//   row_offset   - row to fetch; reduced with (row_offset + 64) % 64.
//   col          - column chunk for this thread (in units of 8 elements).
//   k_idx_       - k-tile index; made wave-uniform via readfirstlane.
//   row_stride   - row pitch; used as a byte count in offset_v and written into
//                  the descriptor's stride field.
//   max_MN       - unused.
//
// NOTE(review): the 512 + 16 byte base adjustment is a hard-coded layout
// assumption whose origin is not visible here - confirm against the caller.
// NOTE(review): the use_asm path differs from the builtin path - it omits the
// block_idx term, puts the row into index_offset[0] although offset_v already
// contains row * row_stride, and passes the two-component operand with only
// `offen` (no `idxen`). Verify before enabling use_asm.
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_sparse_decoding(
Tensor<SrcEngine, SrcLayout> const& src,
uint32x4_t & dst,
int block_idx, int batch_stride,
int row_offset, int col,
int k_idx_, const int row_stride,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 8;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
// Descriptor base = source pointer + 528 bytes, split into low/high dwords.
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get()) + 512 + 16 ;
// Row pitch goes into the stride field (a fixed stride of 2 was tried earlier).
glob_ptr.latter |= ((row_stride) << 16); // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
// The readfirstlane variants of the next two assignments are disabled.
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
// Fixed constants (0x80000000 was tried earlier for word 2); presumably record
// count and format/flags - TODO confirm.
global_addr[2] = 0xFFFFFFFE;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
uint32_t col_offset = col * elements_per_thread + k_idx * 32;
// Byte offset = column bytes + row (mod 64) * row_stride.
uint32_t offset_v = (col_offset * element_size) + ((row_offset + 64 ) % 64 ) * row_stride; // bytes
// A disabled variant forced offset_v = -1 when row_offset == -1.
if constexpr(use_asm) {
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// Component 0 = row index, component 1 = byte offset.
uint32x2_t index_offset = {0};
index_offset[0] = (row_offset + 64 ) % 64;
index_offset[1] = offset_v;
// %0 = dst (output), %1 = index_offset, %2 = buffer descriptor.
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(index_offset), "+s"(global_addr)
);
}
else {
// Index selects the block; the row is already folded into offset_v.
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, (batch_stride / row_stride) * block_idx , offset_v, false, false);
dst = *reinterpret_cast<uint32x4_t*>(&res);
}
}
// /*
// for _64x32, use thread layout is 64x4, per thread get 8 elements, get 64x32 data, put data in lds with 32x64
// for _16x128, use thread layout is 16x16, per thread get 8 elements, get 16x128 data, put data in lds with 32x64
// for _16x192, use thread layout is 16x16, per thread get 12 elements, get 16x192 data, put data in lds with 48x64
// for _16x64_128, use thread layout is 16x16, per thread get 4 elements with offset 128, get 16x64 data, put data in lds with 16x64
// */
// enum MMA_LAYOUT{ _64x32 /* for gemm0 load K */, _16x128 /* for gemm1 load V */, _16x192 /* for dim 192 */, _16x64_128 /* for dim 64 */, _16x64_64 /*for load dim 64 V*/ };
// template <bool Is_even_K=true,
// bool Is_even_MN=true,
// MMA_LAYOUT mma_layout = _64x32,
// int K_BUFF_SIZE = 0,
// class SrcEngine, class SrcLayout,
// class DstEngine, class DstLayout>
// CUTE_HOST_DEVICE
// void
// lds_direct_copy(
// Tensor<SrcEngine, SrcLayout> const& src,
// Tensor<DstEngine, DstLayout> & dst,
// int k_idx_, const int row_stride,
// const int max_K = 0, const int max_MN=0)
// {
// constexpr int warp_size = 64;
// int tidx = threadIdx.x;
// int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
// int lane = tidx % warp_size;
// constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// int k_slide = k_idx;
// if constexpr(K_BUFF_SIZE) {
// k_slide = (k_idx % K_BUFF_SIZE);
// }
// const int offset_s = 0;
// // global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0x80000000;
// global_addr[3] = 0x00020000;
// if constexpr(mma_layout == _64x32) {
// constexpr int elements_per_thread = 8;
// constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// int mma_k = 32*64;
// int row = tidx % 16;
// int col = lane / 16;
// int row_offset = row + (warp_id * 16) ;
// int col_offset = col * elements_per_thread + k_idx * 32;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// #if defined(__gfx936__)
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
// "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
// :);
// #else
// #endif
// } else if constexpr(mma_layout == _16x128) {
// constexpr int elements_per_thread = 8;
// constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// int mma_k = 16*128;
// int row = lane / 4;
// int col = tidx % 4;
// int row_offset = row + k_idx * 16;
// int col_offset = col * elements_per_thread + warp_id * 32;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// #if defined(__gfx936__)
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
// "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
// :);
// #endif
// } else if constexpr(mma_layout == _16x192) {
// constexpr int elements_per_thread = 8;
// constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// int mma_k = 48*64;
// int row = lane / 4;
// int col = tidx % 4;
// int row_offset = row + k_idx * 16;
// int col_offset = col * elements_per_thread + warp_id * 32;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// #if defined(__gfx936__)
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
// "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
// :);
// #endif
// constexpr int elements_per_thread_tail = 4;
// constexpr int bytes_per_warp_tail = warp_size * elements_per_thread_tail * element_size;
// row = (tidx / 8) % 16;
// col = tidx % 8;
// row_offset = row + k_idx * 16;
// col_offset = col * elements_per_thread_tail + warp_id / 2 * 32 + /* pre offset */128 ;
// offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + /* pre offset */64*32 * element_size + warp_id * bytes_per_warp_tail + k_slide * mma_k * element_size;
// // if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
// #if defined(__gfx936__)
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
// "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
// :);
// #endif
// } else if constexpr(mma_layout == _16x64_128) {
// #if 0
// constexpr int elements_per_thread = 4;
// constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// int mma_k = 16*64;
// int row = (tidx / 8) % 16;
// int col = tidx % 8;
// int row_offset = row + k_idx * 16;
// int col_offset = col * elements_per_thread + warp_id / 2 * 32 + /* pre offset */128 ;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// // if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
// #if defined(__gfx936__)
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, glc lds \n" ::"v"(offset_v),
// "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
// :);
// #endif
// #else
// constexpr int elements_per_thread = 8;
// constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// int mma_k = 16*64;
// int row = lane / 4 + (warp_id / 2) * 16;
// int col = tidx % 4;
// int row_offset = row + k_idx * 16;
// int col_offset = col * elements_per_thread + (warp_id % 2) * 32 + 128;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + (warp_id % 2) * bytes_per_warp + k_slide * mma_k * element_size + (warp_id/2)*mma_k * element_size ;
// // if (tidx < 256) printf("tid:%d offset_v:%d row %d col %d ldsAddrPerWave:%d\n", tidx, offset_v, row_offset, col_offset, (warp_id % 2) * bytes_per_warp + k_slide * mma_k * element_size + (warp_id/2)*mma_k * element_size);
// #if defined(__gfx936__)
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
// "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
// :);
// #endif
// #endif
// } else if constexpr(mma_layout == _16x64_64) {
// constexpr int elements_per_thread = 4;
// constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
// int mma_k = 16*64;
// int row = (tidx / 8) % 16;
// int col = tidx % 8;
// int row_offset = row + k_idx * 16;
// int col_offset = col * elements_per_thread + warp_id / 2 * 32;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_K && col_offset >= max_K) offset_v = -1;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
// // if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
// #if defined(__gfx936__)
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, glc lds \n" ::"v"(offset_v),
// "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
// :);
// #endif
// }
// }
#if 1
#define fp8 unsigned char
// Converts one fp8 byte (1 sign / 4 exponent / 3 mantissa bits) to float using
// integer bit manipulation only.
//
// The byte is moved to the top of a 32-bit word, sub-normal inputs are
// renormalised by shifting out the leading zeros, and the exponent is re-biased
// with 0x78 minus that shift.
//
// NOTE(review): there is no zero mask (an earlier version had one). If __clz(0)
// returns 32, a zero input produces the bit pattern (0x78 - 28) << 23, i.e. a
// tiny non-zero value instead of 0.0f. The all-ones exponent is also converted
// like any other value. Confirm both are acceptable to the callers.
inline __device__ float fp8e4m3_to_fp32(const fp8& input) {
    const uint32_t word = (uint32_t)input << 24;
    const uint32_t sign_bit = word & UINT32_C(0x80000000);
    const uint32_t magnitude = word & UINT32_C(0x7FFFFFFF);

    // Normal inputs have at most 4 leading zeros in `magnitude` (sign bit plus
    // a non-zero exponent field); anything beyond that is a renormalisation shift.
    const uint32_t leading_zeros = __clz(magnitude);
    uint32_t shift = 0;
    if (leading_zeros > 4) {
        shift = leading_zeros - 4;
    }

    const uint32_t rebias = (0x78 - shift) << 23;
    const uint32_t aligned = magnitude << shift >> 4;

    union {
        uint32_t bits;
        float value;
    } pun;
    pun.bits = sign_bit | (aligned + rebias);
    return pun.value;
}
// Converts one fp8 byte (1 sign / 4 exponent / 3 mantissa bits) to bfloat16 by
// shifting the exponent/mantissa field into place and adding an exponent re-bias.
//
// NOTE(review): inputs with a zero exponent field are not special-cased (the
// zero check is disabled in the original). A zero input therefore yields the
// bit pattern 0x3C00 (sign bit aside), not 0x0000, and sub-normal inputs are
// treated as if they were normal. Confirm this is acceptable to the callers.
__forceinline__ __device__ cutlass::bfloat16_t fp8e4m3_to_bf16(const fp8& input) {
    // 0x78 placed at the bfloat16 exponent position (7 mantissa bits below it).
    constexpr uint16_t kExponentRebias = (0x78 << 7);

    const uint16_t widened = (uint16_t)input << 8;
    const uint16_t sign_bit = widened & UINT16_C(0x8000);
    const uint16_t magnitude = widened & UINT16_C(0x7FFF);

    const uint16_t bf16_bits = sign_bit | ((magnitude >> 4) + kExponentRebias);
    return cutlass::bfloat16_t::bitcast(bf16_bits);
}
// Converts one fp8 byte to float by placing it in the upper byte of a 16-bit
// half-precision pattern (lower byte zero) and widening that half to float.
__forceinline__ __device__ float fp8e5m2_to_fp32(const fp8& input) {
    union {
        uint16_t bits;
        _Float16 half;
    } pun;
    pun.bits = (uint16_t)input << 8;
    return (float)pun.half;
}
// Decodes one FP8 E5M2 byte into a cutlass::half_t.
//
// No arithmetic is involved: the byte becomes the upper 8 bits of the 16-bit
// pattern (lower 8 bits zero) and that pattern is bit-cast to half_t.
//
// input   : raw FP8 byte.
// returns : half_t whose bit pattern is (input << 8).
__forceinline__ __device__ cutlass::half_t fp8e5m2_to_fp16(const fp8& input) {
    const uint16_t half_bits = static_cast<uint16_t>(input << 8);
    return cutlass::half_t::bitcast(half_bits);
}
#else
#endif
// GEMM step with an FP8-encoded B operand, covering 8 k-slices.
//
// For every k_idx in [0, 8) the function
//   1. reads one 16-byte packet starting at &tCsB(0, 0, k_idx),
//   2. expands the packet's 16 FP8 bytes into `Element` values written to
//      tCrB(0..15, 0, k_idx), and
//   3. calls cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc).
//
// Template parameters:
//   Element            - element type written to tCrB; branches exist for
//                        cutlass::bfloat16_t and cutlass::half_t.
//   is_scale_equal_one - when false, k_scale is multiplied into the decoded
//                        values, but only in the E5M2->bf16 and E4M3->half
//                        branches (see NOTE below).
//   KV_DTYPE           - FP8 encoding of the packed bytes (kFp8E4M3 / kFp8E5M2).
//
// Parameters:
//   acc        - accumulator passed to cute::gemm.
//   tCrA       - A operand, sliced by k_idx.
//   tCrB_int8  - not referenced in this function.
//   tCrB       - destination for the decoded B values.
//   tCsB       - source of the packed FP8 bytes (presumably shared memory/LDS
//                filled by earlier asynchronous loads - confirm with caller).
//   tiled_mma  - MMA descriptor forwarded to cute::gemm.
//   smem_tiled_copy_B, smem_thr_copy_B - not referenced in this function.
//   k_scale    - dequantisation scale, see is_scale_equal_one.
//
// NOTE(review): the E4M3->bf16 and E5M2->half branches never apply k_scale,
// even when is_scale_equal_one is false. Confirm the scale is applied
// elsewhere for those configurations.
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB_int8, Tensor3 &tCrB, Tensor4 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, const float& k_scale )
{
typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef unsigned int __hip_fp8x4_storage_t;
typedef unsigned short int __hip_fp8x2_storage_t;
typedef unsigned char __hip_fp8_storage_t;
// One 128-bit packet per k-slice, viewable either as a vector or as four
// 32-bit words holding four FP8 bytes each.
union {
__fp16x8_t data_128;
__hip_fp8x4_storage_t fp8_array[4];
} data[8];
// The scheduling barriers fence the load sequence so the compiler does not
// move instructions across it.
__builtin_amdgcn_sched_barrier(0);
// Each packet is read after a wait_vmcnt<N> with N counting down from 8 to 1
// (presumably: wait until at most N vector-memory operations are still in
// flight, so packet k is only read once its producer load has landed -
// confirm against the wait_vmcnt helper).
wait_vmcnt<8>();
data[0].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 0));
wait_vmcnt<7>();
data[1].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 1));
wait_vmcnt<6>();
data[2].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 2));
wait_vmcnt<5>();
data[3].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 3));
wait_vmcnt<4>();
data[4].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 4));
wait_vmcnt<3>();
data[5].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 5));
wait_vmcnt<2>();
data[6].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 6));
wait_vmcnt<1>();
data[7].data_128 = *reinterpret_cast<__fp16x8_t *>(&tCsB(0, 0, 7));
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int k_idx = 0; k_idx < 8; k_idx++)
{
// j walks the 16 bytes of the packet four at a time; j / 4 selects the
// 32-bit word, which is then split into a low and a high 16-bit half.
#pragma unroll
for (int j = 0; j < 16; j+=4) {
auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx].fp8_array[j / 4]);
auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx].fp8_array[j / 4])) + 1);
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::bfloat16_t>) {
// E4M3 -> bf16: direct bit conversion, k_scale is not applied.
// The static_cast to an 8-bit type keeps only the low byte of its operand.
auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
auto rst0 = fp8e4m3_to_bf16(f1);
auto rst1 = fp8e4m3_to_bf16(f2);
auto rst2 = fp8e4m3_to_bf16(f3);
auto rst3 = fp8e4m3_to_bf16(f4);
tCrB(j, 0, k_idx) = rst0;
tCrB(j + 1, 0, k_idx) = rst1;
tCrB(j + 2, 0, k_idx) = rst2;
tCrB(j + 3, 0, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::bfloat16_t>) {
// E5M2 -> bf16: decode to float, optionally scale, then convert with
// round-toward-zero.
auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto rst0 = convert_(f1);
auto rst1 = convert_(f2);
auto rst2 = convert_(f3);
auto rst3 = convert_(f4);
tCrB(j, 0, k_idx) = rst0;
tCrB(j + 1, 0, k_idx) = rst1;
tCrB(j + 2, 0, k_idx) = rst2;
tCrB(j + 3, 0, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::half_t>) {
// E5M2 -> half: each FP8 byte becomes the upper byte of a 16-bit half
// pattern, done two bytes at a time with masks; k_scale is not applied.
// auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
// auto f2 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
// auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
// auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
// tCrB(j, 0, k_idx) = f1;
// tCrB(j + 1, 0, k_idx) = f2;
// tCrB(j + 2, 0, k_idx) = f3;
// tCrB(j + 3, 0, k_idx) = f4;
__hip_fp8x4_storage_t fp8_data = data[k_idx].fp8_array[j / 4];
union Fp8_data_union{
__hip_fp8x4_storage_t fp8x4;
uint16_t fp16[2];
} ;
Fp8_data_union first_fp8, last_fp8;
// first_fp8 keeps bytes 1 and 3 in place (already the high byte of each
// 16-bit half); last_fp8 moves bytes 0 and 2 up into the high byte.
first_fp8.fp8x4 = ((fp8_data & 0xff00ff00));
last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8);
tCrB(j, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]);
tCrB(j + 1, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);;
tCrB(j + 2, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);;
tCrB(j + 3, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;;
}
else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::half_t>) {
// E4M3 -> half: decode to float, optionally scale, then pack pairs of
// floats into two halves with __builtin_amdgcn_cvt_pkrtz.
auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f2);
auto rst1 = __builtin_amdgcn_cvt_pkrtz(f3, f4);
cutlass::Array<half_t, 2> result0 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst0);
cutlass::Array<half_t, 2> result1 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst1);
tCrB(j, 0, k_idx) = result0[0];
tCrB(j + 1, 0, k_idx) = result0[1];
tCrB(j + 2, 0, k_idx) = result1[0];
tCrB(j + 3, 0, k_idx) = result1[1];
}
}
// The k-slice is fully decoded; accumulate it.
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
}
#if 1
// Single-k-slice GEMM step with an FP8-encoded B operand.
//
// Expands the 16 FP8 bytes held in `_data` into `Element` values written to
// tCrB(0..15, 0, k_idx), then calls
// cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc).
//
// Template parameters:
//   Element            - element type written to tCrB; branches exist for
//                        cutlass::bfloat16_t and cutlass::half_t.
//   k_idx              - compile-time k-slice index.
//   is_scale_equal_one - when false, k_scale is multiplied into the decoded
//                        values, but only in the E5M2->bf16 and E4M3->half
//                        branches.
//   KV_DTYPE           - FP8 encoding of the packed bytes (kFp8E4M3 / kFp8E5M2).
//
// Parameters:
//   acc       - accumulator passed to cute::gemm.
//   tCrA      - A operand, sliced by k_idx.
//   tCrB      - destination for the decoded B values.
//   tiled_mma - MMA descriptor forwarded to cute::gemm.
//   _data     - 128-bit packet holding 16 FP8 bytes (read only).
//   k_scale   - dequantisation scale, see is_scale_equal_one.
//
// NOTE(review): the E4M3->bf16 and E5M2->half branches never apply k_scale,
// even when is_scale_equal_one is false. Confirm the scale is applied
// elsewhere for those configurations.
template<typename Element, int k_idx, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1,typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_k_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, TiledMma tiled_mma, uint32x4_t& _data, const float& k_scale)
{
    typedef unsigned int __hip_fp8x4_storage_t;
    typedef unsigned short int __hip_fp8x2_storage_t;
    typedef unsigned char __hip_fp8_storage_t;
    // View the 128-bit packet as four 32-bit words of four FP8 bytes each.
    union {
        uint32x4_t data_128;
        __hip_fp8x4_storage_t fp8_array[4];
    } data;
    data.data_128 = _data;
    // Unrolled like the sibling gemm_rs_fp8 / gemm1_rs_fp8 loops so that every
    // tCrB(j, ...) index is a compile-time constant after unrolling.
    #pragma unroll
    for (int j = 0; j < 16; j+=4) {
        // j / 4 selects the 32-bit word; split it into its two 16-bit halves.
        auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data.fp8_array[j / 4]);
        auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data.fp8_array[j / 4])) + 1);
        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::bfloat16_t>) {
            // E4M3 -> bf16: direct bit conversion, k_scale is not applied.
            // The static_cast to an 8-bit type keeps only the low byte.
            auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
            auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
            auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
            auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
            auto rst0 = fp8e4m3_to_bf16(f1);
            auto rst1 = fp8e4m3_to_bf16(f2);
            auto rst2 = fp8e4m3_to_bf16(f3);
            auto rst3 = fp8e4m3_to_bf16(f4);
            tCrB(j, 0, k_idx) = rst0;
            tCrB(j + 1, 0, k_idx) = rst1;
            tCrB(j + 2, 0, k_idx) = rst2;
            tCrB(j + 3, 0, k_idx) = rst3;
        } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::bfloat16_t>) {
            // E5M2 -> bf16: decode to float, optionally scale, then convert
            // with round-toward-zero.
            auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
            auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
            auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
            auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
            if constexpr(!is_scale_equal_one) {
                f1 *= k_scale;
                f2 *= k_scale;
                f3 *= k_scale;
                f4 *= k_scale;
            }
            cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
            auto rst0 = convert_(f1);
            auto rst1 = convert_(f2);
            auto rst2 = convert_(f3);
            auto rst3 = convert_(f4);
            tCrB(j, 0, k_idx) = rst0;
            tCrB(j + 1, 0, k_idx) = rst1;
            tCrB(j + 2, 0, k_idx) = rst2;
            tCrB(j + 3, 0, k_idx) = rst3;
        } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::half_t>) {
            // E5M2 -> half: each FP8 byte becomes the upper byte of a 16-bit
            // half pattern, two bytes at a time; k_scale is not applied.
            __hip_fp8x4_storage_t fp8_data = data.fp8_array[j / 4];
            union Fp8_data_union{
                __hip_fp8x4_storage_t fp8x4;
                uint16_t fp16[2];
            } ;
            Fp8_data_union first_fp8, last_fp8;
            // first_fp8 keeps bytes 1 and 3 in place; last_fp8 moves bytes 0
            // and 2 up into the high byte of each 16-bit half.
            first_fp8.fp8x4 = ((fp8_data & 0xff00ff00));
            last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8);
            tCrB(j, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]);
            tCrB(j + 1, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);
            tCrB(j + 2, 0, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);
            tCrB(j + 3, 0, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);
        }
        else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::half_t>) {
            // E4M3 -> half: decode to float, optionally scale, then pack pairs
            // of floats into two halves with __builtin_amdgcn_cvt_pkrtz.
            auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
            auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
            auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
            auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
            if constexpr(!is_scale_equal_one) {
                f1 *= k_scale;
                f2 *= k_scale;
                f3 *= k_scale;
                f4 *= k_scale;
            }
            auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f2);
            auto rst1 = __builtin_amdgcn_cvt_pkrtz(f3, f4);
            cutlass::Array<half_t, 2> result0 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst0);
            cutlass::Array<half_t, 2> result1 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst1);
            tCrB(j, 0, k_idx) = result0[0];
            tCrB(j + 1, 0, k_idx) = result0[1];
            tCrB(j + 2, 0, k_idx) = result1[0];
            tCrB(j + 3, 0, k_idx) = result1[1];
        }
    }
    cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
#else
#endif
// GEMM step with an FP8-encoded B operand, covering 4 k-slices of 2 packets each.
//
// Eight 128-bit packets are read up front with
// __builtin_amdgcn_ds_read_m32x16f16 at compile-time offsets derived from
// tCsB's layout. Packet (k_idx * 2 + i) is expanded into `Element` values
// written to tCrB(0..15, i, k_idx) for i in {0, 1}; after both packets of a
// k-slice are decoded, cute::gemm(tiled_mma, tCrA(_, _, k_idx),
// tCrB(_, _, k_idx), acc) is called.
//
// Template parameters:
//   Element            - element type written to tCrB; branches exist for
//                        cutlass::bfloat16_t and cutlass::half_t.
//   is_scale_equal_one - when false, k_scale is multiplied into the decoded
//                        values, but only in the E5M2->bf16 and E4M3->half
//                        branches.
//   KV_DTYPE           - FP8 encoding of the packed bytes (kFp8E4M3 / kFp8E5M2).
//
// Parameters:
//   acc       - accumulator passed to cute::gemm.
//   tCrA      - A operand, sliced by k_idx.
//   tCrB      - destination for the decoded B values.
//   tCsB      - source tensor; its address is cast to address_space(3) for
//               the read builtin and its layout supplies the offsets.
//   tiled_mma - MMA descriptor forwarded to cute::gemm.
//   smem_tiled_copy_B, smem_thr_copy_B - not referenced in this function.
//   k_scale   - dequantisation scale, see is_scale_equal_one.
//
// NOTE(review): the E4M3->bf16 and E5M2->half branches never apply k_scale,
// even when is_scale_equal_one is false. Confirm the scale is applied
// elsewhere for those configurations.
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor3 &tCrB, Tensor4 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B, const float& k_scale ) {
typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef unsigned int __hip_fp8x4_storage_t;
typedef unsigned short int __hip_fp8x2_storage_t;
typedef unsigned char __hip_fp8_storage_t;
auto lds = reinterpret_cast<__fp16 *>(&tCsB(0, 0, 0));
auto layout = tCsB.layout();
// One 128-bit packet per (k_idx, i) pair, viewable either as a vector or as
// four 32-bit words holding four FP8 bytes each.
union {
__fp16x8_t data_128;
__hip_fp8x4_storage_t fp8_array[4];
} data[8];
// Offsets are layout(0, i, k) * 2 for i in {0, 1}, k in 0..3, stored at
// data[k * 2 + i]. (The factor 2 presumably converts the layout index into
// the unit the builtin expects - TODO confirm against the builtin's
// documentation.)
constexpr short offset0 = layout(0, 0, 0) * 2;
data[0].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset0);
constexpr short offset1 = layout(0, 1, 0) * 2;
data[1].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset1);
constexpr short offset2 = layout(0, 0, 1) * 2;
data[2].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset2);
constexpr short offset3 = layout(0, 1, 1) * 2;
data[3].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset3);
constexpr short offset4 = layout(0, 0, 2) * 2;
data[4].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset4);
constexpr short offset5 = layout(0, 1, 2) * 2;
data[5].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset5);
constexpr short offset6 = layout(0, 0, 3) * 2;
data[6].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset6);
constexpr short offset7 = layout(0, 1, 3) * 2;
data[7].data_128 = __builtin_amdgcn_ds_read_m32x16f16((__attribute__((address_space(3))) __fp16*)(lds), offset7);
#pragma unroll
for (int k_idx = 0; k_idx < 4; k_idx++) {
#pragma unroll
for (int i = 0; i < 2; i++) {
// j walks the 16 bytes of the packet four at a time; j / 4 selects the
// 32-bit word, which is then split into a low and a high 16-bit half.
#pragma unroll
for (int j = 0; j < 16; j+=4) {
auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx * 2 + i].fp8_array[j / 4]);
auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx * 2 + i].fp8_array[j / 4])) + 1);
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::bfloat16_t>) {
// E4M3 -> bf16: direct bit conversion, k_scale is not applied.
// cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto f1 = (static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = (static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = (static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = (static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
auto rst0 = fp8e4m3_to_bf16(f1);
auto rst1 = fp8e4m3_to_bf16(f2);
auto rst2 = fp8e4m3_to_bf16(f3);
auto rst3 = fp8e4m3_to_bf16(f4);
tCrB(j, i, k_idx) = rst0;
tCrB(j + 1, i, k_idx) = rst1;
tCrB(j + 2, i, k_idx) = rst2;
tCrB(j + 3, i, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::bfloat16_t>) {
// E5M2 -> bf16: decode to float, optionally scale, then convert with
// round-toward-zero.
auto f1 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e5m2_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto rst0 = convert_(f1);
auto rst1 = convert_(f2);
auto rst2 = convert_(f3);
auto rst3 = convert_(f4);
tCrB(j, i, k_idx) = rst0;
tCrB(j + 1, i, k_idx) = rst1;
tCrB(j + 2, i, k_idx) = rst2;
tCrB(j + 3, i, k_idx) = rst3;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E5M2 && std::is_same_v<Element, cutlass::half_t>) {
// E5M2 -> half: each FP8 byte becomes the upper byte of a 16-bit half
// pattern, done two bytes at a time with masks; k_scale is not applied.
// auto f1 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
// auto f2 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
// auto f3 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
// auto f4 = fp8e5m2_to_fp16(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
// tCrB(j, i, k_idx) = f1;
// tCrB(j + 1, i, k_idx) = f2;
// tCrB(j + 2, i, k_idx) = f3;
// tCrB(j + 3, i, k_idx) = f4;
__hip_fp8x4_storage_t fp8_data = data[k_idx * 2 + i].fp8_array[j / 4];
union Fp8_data_union{
__hip_fp8x4_storage_t fp8x4;
uint16_t fp16[2];
} ;
Fp8_data_union first_fp8, last_fp8;
// first_fp8 keeps bytes 1 and 3 in place (already the high byte of each
// 16-bit half); last_fp8 moves bytes 0 and 2 up into the high byte.
first_fp8.fp8x4 = ((fp8_data & 0xff00ff00));
last_fp8.fp8x4 = ((fp8_data & 0x00ff00ff) << 8);
tCrB(j, i, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[0]);
tCrB(j + 1, i, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[0]);;
tCrB(j + 2, i, k_idx) = cutlass::half_t::bitcast(last_fp8.fp16[1]);;
tCrB(j + 3, i, k_idx) = cutlass::half_t::bitcast(first_fp8.fp16[1]);;;
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kFp8E4M3 && std::is_same_v<Element, cutlass::half_t>){
// E4M3 -> half: decode to float, optionally scale, then pack with
// __builtin_amdgcn_cvt_pkrtz. The packing pairs (f1, f3) and (f2, f4),
// but the stores below still place the values in f1, f2, f3, f4 order.
auto f1 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
if constexpr(!is_scale_equal_one) {
f1 *= k_scale;
f2 *= k_scale;
f3 *= k_scale;
f4 *= k_scale;
}
auto rst0 = __builtin_amdgcn_cvt_pkrtz(f1, f3);
auto rst1 = __builtin_amdgcn_cvt_pkrtz(f2, f4);
cutlass::Array<half_t, 2> result0 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst0);
cutlass::Array<half_t, 2> result1 = reinterpret_cast<cutlass::Array<half_t, 2> &>(rst1);
tCrB(j, i, k_idx) = result0[0];
tCrB(j + 1, i, k_idx) = result1[0];
tCrB(j + 2, i, k_idx) = result0[1];
tCrB(j + 3, i, k_idx) = result1[1];
}
}
}
// Both packets of this k-slice are decoded; accumulate it.
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
}
// Single-k-slice GEMM step: copies slice k_idx of the B operand from tCsB
// into tCrB with the supplied tiled copy, then accumulates
// cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc).
//
// Parameters:
//   acc               - accumulator passed to cute::gemm.
//   tCrA              - A operand, sliced by k_idx.
//   tCrB              - B operand destination; filled for slice k_idx here.
//   tCsB              - B operand source for the copy.
//   tiled_mma         - MMA descriptor forwarded to cute::gemm.
//   smem_tiled_copy_B - tiled copy used to move the slice.
//   smem_thr_copy_B   - per-thread copy slice, used to retile tCrB as a
//                       copy destination.
//   k_idx             - runtime index of the k-slice to process.
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                          TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                          ThrCopy smem_thr_copy_B, int k_idx) {
    // Shape agreement between the MMA operands and the accumulator.
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K

    // Same storage as tCrB, re-tiled to the shape the copy expects.
    auto b_dst_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(b_dst_view));  // N

    cute::copy(smem_tiled_copy_B, tCsB(_, _, k_idx), b_dst_view(_, _, k_idx));
    cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
// Issues one `buffer_load_dwordx4 ... lds` per thread, i.e. a 16-byte load
// whose destination is selected through the M0 register.
//
// The per-thread source byte offset is
//     row * row_stride + col
// with row = tidx % 16 + (wave % 4) * 16 and
//      col = (lane / 16) * 16 + (wave / 4) * 64 + k_idx * 128.
// The per-wave destination address is
//     dst + (wave % 4) * 1024 + k_idx * 64 * 128 + (wave / 4) * 64 * 64.
//
// Template parameters:
//   Is_even_MN - when false, rows >= max_MN get offset -1 instead of a real offset.
//   Is_even_K  - when false, columns >= 576 get offset -1 instead of a real offset.
//   Is_load_Q  - not referenced in this function.
//
// Parameters:
//   src        - source tensor; its base pointer goes into the buffer descriptor.
//   dst        - destination tensor; its base pointer is the LDS base address.
//   k_idx_     - k-block index (made wave-uniform with readfirstlane).
//   row_stride - byte stride between consecutive source rows.
//   max_MN     - row bound used when Is_even_MN is false.
//
// The load is only emitted when compiling for __gfx938__; on every other
// target the function computes the addresses and issues nothing.
template <
    bool Is_even_MN=true,
    bool Is_even_K=true,
    bool Is_load_Q=false,
    class SrcEngine, class SrcLayout,
    class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8_q_tp1(
    Tensor<SrcEngine, SrcLayout> const& src,
    Tensor<DstEngine, DstLayout>      & dst,
    int k_idx_, const int row_stride,
    const int max_MN=0
)
{
    constexpr int kWaveSize     = 64;
    constexpr int kElemBytes    = 1;
    constexpr int kElemsPerLane = 16;
    constexpr int kBytesPerWave = kWaveSize * kElemsPerLane * kElemBytes;  // 64 * 16 * 1

    const int thread_id = threadIdx.x;
    const int wave      = __builtin_amdgcn_readfirstlane(thread_id / kWaveSize);
    const int lane_id   = thread_id % kWaveSize;
    const int k_block   = __builtin_amdgcn_readfirstlane(k_idx_);
    const int soffset   = 0;

    // Split the 64-bit source pointer into the first two descriptor words.
    struct AddrWords {
        uint32_t word0;
        uint32_t word1;
    };
    AddrWords src_addr;
    *(uint64_t*)&src_addr = reinterpret_cast<uint64_t>(src.data().get());

    // Words 2 and 3 are fixed constants (presumably the record count and the
    // format/flags word of the buffer resource descriptor - TODO confirm
    // against the ISA documentation).
    uint32x4_t buffer_rsrc = {0};
    buffer_rsrc[0] = __builtin_amdgcn_readfirstlane(src_addr.word0);
    buffer_rsrc[1] = __builtin_amdgcn_readfirstlane(src_addr.word1);
    buffer_rsrc[2] = 0x80000000;
    buffer_rsrc[3] = 0x00020000;

    const int row = thread_id % 16 + (wave % 4) * 16;
    const int col = (lane_id / 16) * 16 + (wave / 4) * 64 + k_block * 128;

    // Out-of-range rows/columns get offset -1 (presumably so the buffer load
    // is rejected by the descriptor's range check - TODO confirm).
    const bool row_oob = !Is_even_MN && row >= max_MN;
    const bool col_oob = !Is_even_K && col >= 576;
    const int voffset = (row_oob || col_oob) ? -1 : row * row_stride + col * kElemBytes;  // bytes

    const int lds_wave_addr = reinterpret_cast<size_t>(dst.data().get()) + (wave % 4) * kBytesPerWave + k_block * 64*128 * kElemBytes + (wave / 4) * 64 * 64;
#if defined(__gfx938__)
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(voffset),
        "s"(lds_wave_addr), "s"(buffer_rsrc), "s"(soffset)
        :);
#endif
}
// Issues one `buffer_load_dwordx4 ... lds` per thread, i.e. a 16-byte load
// whose destination is selected through the M0 register. This variant
// arranges the waves in pairs.
//
// The per-thread source byte offset is
//     row * row_stride + col
// with row = tidx % 16 + (wave % 2) * 16 and
//      col = (lane / 16) * 16 + (wave / 2) * 64 + k_idx * 128.
// The per-wave destination address is
//     dst + (wave % 2) * 1024 + k_idx * 32 * 128 + (wave / 2) * 32 * 64.
//
// Template parameters:
//   Is_even_MN - when false, rows >= max_MN get offset -1 instead of a real offset.
//   Is_even_K  - when false, columns >= 576 get offset -1 instead of a real offset.
//   Is_load_Q  - not referenced in this function.
//
// Parameters:
//   src        - source tensor; its base pointer goes into the buffer descriptor.
//   dst        - destination tensor; its base pointer is the LDS base address.
//   k_idx_     - k-block index (made wave-uniform with readfirstlane).
//   row_stride - byte stride between consecutive source rows.
//   max_MN     - row bound used when Is_even_MN is false.
//
// The load is only emitted when compiling for __gfx938__; on every other
// target the function computes the addresses and issues nothing.
template <
    bool Is_even_MN=true,
    bool Is_even_K=true,
    bool Is_load_Q=false,
    class SrcEngine, class SrcLayout,
    class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8_q_tp4(
    Tensor<SrcEngine, SrcLayout> const& src,
    Tensor<DstEngine, DstLayout>      & dst,
    int k_idx_, const int row_stride,
    const int max_MN=0
)
{
    constexpr int kWaveSize     = 64;
    constexpr int kElemBytes    = 1;
    constexpr int kElemsPerLane = 16;
    constexpr int kBytesPerWave = kWaveSize * kElemsPerLane * kElemBytes;  // 64 * 16 * 1

    const int thread_id = threadIdx.x;
    const int wave      = __builtin_amdgcn_readfirstlane(thread_id / kWaveSize);
    const int lane_id   = thread_id % kWaveSize;
    const int k_block   = __builtin_amdgcn_readfirstlane(k_idx_);
    const int soffset   = 0;

    // Split the 64-bit source pointer into the first two descriptor words.
    struct AddrWords {
        uint32_t word0;
        uint32_t word1;
    };
    AddrWords src_addr;
    *(uint64_t*)&src_addr = reinterpret_cast<uint64_t>(src.data().get());

    // Words 2 and 3 are fixed constants (presumably the record count and the
    // format/flags word of the buffer resource descriptor - TODO confirm
    // against the ISA documentation).
    uint32x4_t buffer_rsrc = {0};
    buffer_rsrc[0] = __builtin_amdgcn_readfirstlane(src_addr.word0);
    buffer_rsrc[1] = __builtin_amdgcn_readfirstlane(src_addr.word1);
    buffer_rsrc[2] = 0x80000000;
    buffer_rsrc[3] = 0x00020000;

    const int row = thread_id % 16 + (wave % 2) * 16;
    const int col = (lane_id / 16) * 16 + (wave / 2) * 64 + k_block * 128;

    // Out-of-range rows/columns get offset -1 (presumably so the buffer load
    // is rejected by the descriptor's range check - TODO confirm).
    const bool row_oob = !Is_even_MN && row >= max_MN;
    const bool col_oob = !Is_even_K && col >= 576;
    const int voffset = (row_oob || col_oob) ? -1 : row * row_stride + col * kElemBytes;  // bytes

    const int lds_wave_addr = reinterpret_cast<size_t>(dst.data().get()) + (wave % 2) * kBytesPerWave + k_block * 32*128 * kElemBytes + (wave / 2) * 32 * 64;
#if defined(__gfx938__)
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(voffset),
        "s"(lds_wave_addr), "s"(buffer_rsrc), "s"(soffset)
        :);
#endif
}
// Issues one `buffer_load_dwordx4 ... lds` per thread, i.e. a 16-byte load
// whose destination is selected through the M0 register. Unlike the `_q_`
// variants, the source row/column of each lane is permuted (XOR swizzle)
// before the offset is formed.
//
// Per lane:
//   swizzle = (lane / 8) ^ (lane % 8)
//   row     = lane / 4, with its lowest bit flipped when row >= 8
//             (8 <-> 9, 10 <-> 11, ...)
//   col     = swizzle % 4
// Source byte offset:
//   (row + (wave % 4) * 16) * row_stride
//     + col * 16 + k_idx * 128 + (wave / 4) * 64
// Per-wave destination address:
//   dst + (wave % 4) * 1024 + k_idx * 64 * 128 + (wave / 4) * 64 * 64
//
// Template parameters:
//   Is_even_MN - when false, rows >= max_MN get offset -1 instead of a real offset.
//   Is_even_K  - when false, columns >= 576 get offset -1 instead of a real offset.
//   Is_load_Q  - not referenced in this function.
//
// Parameters:
//   src        - source tensor; its base pointer goes into the buffer descriptor.
//   dst        - destination tensor; its base pointer is the LDS base address.
//   k_idx_     - k-block index (made wave-uniform with readfirstlane).
//   row_stride - byte stride between consecutive source rows.
//   max_MN     - row bound used when Is_even_MN is false.
//
// The load is only emitted when compiling for __gfx938__; on every other
// target the function computes the addresses and issues nothing.
template <
    bool Is_even_MN=true,
    bool Is_even_K=true,
    bool Is_load_Q=false,
    class SrcEngine, class SrcLayout,
    class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8_tp1(
    Tensor<SrcEngine, SrcLayout> const& src,
    Tensor<DstEngine, DstLayout>      & dst,
    int k_idx_, const int row_stride,
    const int max_MN=0)
{
    constexpr int kWaveSize     = 64;
    constexpr int kElemBytes    = 1;
    constexpr int kElemsPerLane = 16;
    constexpr int kBytesPerWave = kWaveSize * kElemsPerLane * kElemBytes;  // 64 * 16 * 1

    const int thread_id = threadIdx.x;
    const int wave      = __builtin_amdgcn_readfirstlane(thread_id / kWaveSize);
    const int lane_id   = thread_id % kWaveSize;
    const int k_block   = __builtin_amdgcn_readfirstlane(k_idx_);
    const int soffset   = 0;

    // Split the 64-bit source pointer into the first two descriptor words.
    struct AddrWords {
        uint32_t word0;
        uint32_t word1;
    };
    AddrWords src_addr;
    *(uint64_t*)&src_addr = reinterpret_cast<uint64_t>(src.data().get());

    // Words 2 and 3 are fixed constants (presumably the record count and the
    // format/flags word of the buffer resource descriptor - TODO confirm
    // against the ISA documentation).
    uint32x4_t buffer_rsrc = {0};
    buffer_rsrc[0] = __builtin_amdgcn_readfirstlane(src_addr.word0);
    buffer_rsrc[1] = __builtin_amdgcn_readfirstlane(src_addr.word1);
    buffer_rsrc[2] = 0x80000000;
    buffer_rsrc[3] = 0x00020000;

    // TODO: this part can still be optimised; for the last 8 rows the row
    // numbers have to be swapped pairwise.
    const int swizzled  = (lane_id / 8) ^ (lane_id % 8);
    const int plain_row = lane_id / 4;
    const int tile_row  = plain_row ^ (plain_row >= 8 ? 1 : 0);  // 8 -> 9, 9 -> 8, ...
    const int tile_col  = swizzled % 4;

    const int src_row = tile_row + ((wave % 4) * 16);
    const int src_col = tile_col * kElemsPerLane + k_block * 128 + (wave / 4) * 64;

    // Out-of-range rows/columns get offset -1 (presumably so the buffer load
    // is rejected by the descriptor's range check - TODO confirm).
    const bool row_oob = !Is_even_MN && src_row >= max_MN;
    const bool col_oob = !Is_even_K && src_col >= 576;
    const int voffset = (row_oob || col_oob) ? -1 : src_row * row_stride + src_col * kElemBytes;  // bytes

    const int lds_wave_addr = reinterpret_cast<size_t>(dst.data().get()) + (wave % 4) * kBytesPerWave + k_block * 64*128 * kElemBytes + (wave / 4) * 64 * 64;
#if defined(__gfx938__)
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(voffset),
        "s"(lds_wave_addr), "s"(buffer_rsrc), "s"(soffset)
        :);
#endif
}
// template <class TiledMma, class TiledMma_O,
// typename Engine0, typename Layout0,
// typename Engine1, typename Layout1
// >
// __forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor<Engine0, Layout0> const& tOrP,
// Tensor<Engine1, Layout1> const& sAcc)
// {
// using Value_type = typename Engine0::value_type;
// int tidx = threadIdx.x;
// auto smem_tiled_copy_ACC = make_tiled_copy_C(Copy_Atom<DefaultCopy, Value_type>{}, tiled_mma);
// auto smem_thr_copy_ACC = smem_tiled_copy_ACC.get_thread_slice(tidx);
// Tensor taccOr = smem_thr_copy_ACC.retile_S(tOrP);
// Tensor taccOs = smem_thr_copy_ACC.partition_D(sAcc);
// cute::copy(smem_tiled_copy_ACC, taccOr, taccOs);
// // asm volatile("s_waitcnt lgkmcnt(0)\n\t");
// __syncthreads();
// // wangaq debug
// // if (tidx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// // int col = 8;
// // for (int i = 0; i < 16*64/col; ++i) {
// // printf("sP:%d ", i);
// // for (int j = 0; j < col; ++j) {
// // printf("%10.4f ", float(sAcc(i*col+j)));
// // }
// // printf("\n");
// // }
// // }
// auto thr_mma = tiled_mma_o.get_thread_slice(tidx);
// auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<DefaultCopy, Value_type>{}, tiled_mma_o);
// auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
// Tensor tSsACC = smem_thr_copy_A.partition_S(sAcc);
// Tensor tSrACC = thr_mma.partition_fragment_A(sAcc);
// Tensor tSrACC_copy_view = smem_thr_copy_A.retile_D(tSrACC);
// cute::copy(smem_tiled_copy_ACC, tSsACC, tSrACC_copy_view);
// return tSrACC;
// }
// Moves four per-thread values of tOrP through the scratch tensor sAcc and
// returns them re-gathered as an A-operand fragment of tiled_mma_o.
//
// Steps:
//   1. Each thread stores tOrP(0..3, 0, 0) into sAcc at
//        (lane % 16) * 8 + lane / 16 + v * 128 + (wave % 2) * 4 + (wave / 2) * 512
//      for v = 0..3.
//   2. __syncthreads().
//   3. Each thread reads 16 values back: for k = 0..3 and v = 0..3,
//        frag(v, 0, k) = sAcc(lane * 8 + v + (k % 2) * 4 + (k / 2) * 512).
//
// Parameters:
//   tiled_mma   - not referenced in this function.
//   tiled_mma_o - supplies the thread slice that defines the fragment shape
//                 (partition_fragment_A over a 16x64 row-major layout).
//   tOrP        - source values; only elements (0..3, 0, 0) are read.
//   sAcc        - scratch tensor written and then read; the caller must make
//                 sure earlier readers are done with it, as no barrier is
//                 issued before the stores.
//
// Returns the fragment produced by partition_fragment_A, filled as in step 3.
template <class TiledMma, class TiledMma_O,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1
>
__forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor<Engine0, Layout0> const& tOrP,
                                                         Tensor<Engine1, Layout1> const& sAcc)
{
    const int lane = threadIdx.x % 64;
    const int wave = threadIdx.x / 64;

    // Scatter: consecutive values of one thread land 16 * 8 elements apart.
    const int store_base = (lane % 16) * 8 + (lane / 16) + (wave % 2) * 4 + (wave / 2) * 16 * 32;
    #pragma unroll
    for (int v = 0; v < 4; ++v) {
        sAcc(store_base + v * 16 * 8) = tOrP(v, 0, 0);
    }
    __syncthreads();

    using SmemLayoutAtomP = Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
    using SmemLayoutP = decltype(tile_to_shape(
        SmemLayoutAtomP{},
        Shape<Int<16>, Int<64>>{}));
    Tensor sP_tmp = make_tensor(sAcc.data(), SmemLayoutP{});
    auto thr_mma = tiled_mma_o.get_thread_slice(lane);
    Tensor frag = thr_mma.partition_fragment_A(sP_tmp);

    // Gather: each thread reads two runs of 8 consecutive elements, the second
    // run 16 * 32 elements after the first.
    const int load_base = lane * 8;
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
        const int k_offset = (k % 2) * 4 + (k / 2) * 16 * 32;
        #pragma unroll
        for (int v = 0; v < 4; ++v) {
            frag(v, 0, k) = sAcc(load_base + v + k_offset);
        }
    }
    return frag;
}
#if 0
// Fragment-returning variant of convert_layout_acc_Aregs_fp8.
// This definition sits in the `#if 0` branch and is therefore compiled out;
// the `#else` branch holds the variant that is actually built.
//
// Each thread stores tOrP(0..3, 0, 0) into four consecutive sAcc elements,
// synchronises the block, and then reads two runs of 8 consecutive elements
// (starting at tid * 8 and tid * 8 + 16 * 32) into an A-operand fragment of
// tiled_mma_o, which is returned.
template <class TiledMma, class TiledMma_O,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ auto convert_layout_acc_Aregs_fp8(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor<Engine0, Layout0> const& tOrP,
Tensor<Engine1, Layout1> const& sAcc)
{
using Value_type = typename Engine0::value_type;
// tid is the index within a 64-thread group, warp_id the group index.
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
// Scatter the four values of this thread into consecutive slots.
sAcc((tid % 16 ) * 8 + (tid / 16)*4 + (tid / 32)*120 + 0 + warp_id * 32 * 8) = tOrP(0, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16)*4 + (tid / 32)*120 + 1 + warp_id * 32 * 8) = tOrP(1, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16)*4 + (tid / 32)*120 + 2 + warp_id * 32 * 8) = tOrP(2, 0, 0);
sAcc((tid % 16 ) * 8 + (tid / 16)*4 + (tid / 32)*120 + 3 + warp_id * 32 * 8) = tOrP(3, 0, 0);
__syncthreads();
// A 16x64 row-major view of sAcc, used only to obtain the fragment shape.
using SmemLayoutAtomP = Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<16>, Int<64>>{}));
Tensor sP_tmp = make_tensor(sAcc.data(),SmemLayoutP{});
auto thr_mma = tiled_mma_o.get_thread_slice(tid);
Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp);
// First k-slice: 8 consecutive elements starting at tid * 8.
tSrACC(0, 0, 0) = sAcc(tid * 8 + 0);
tSrACC(1, 0, 0) = sAcc(tid * 8 + 1);
tSrACC(2, 0, 0) = sAcc(tid * 8 + 2);
tSrACC(3, 0, 0) = sAcc(tid * 8 + 3);
tSrACC(4, 0, 0) = sAcc(tid * 8 + 4);
tSrACC(5, 0, 0) = sAcc(tid * 8 + 5);
tSrACC(6, 0, 0) = sAcc(tid * 8 + 6);
tSrACC(7, 0, 0) = sAcc(tid * 8 + 7);
// Second k-slice: the same run, 16 * 32 elements further on.
tSrACC(0, 0, 1) = sAcc(tid * 8 + 0 + 16 * 32);
tSrACC(1, 0, 1) = sAcc(tid * 8 + 1 + 16 * 32);
tSrACC(2, 0, 1) = sAcc(tid * 8 + 2 + 16 * 32);
tSrACC(3, 0, 1) = sAcc(tid * 8 + 3 + 16 * 32);
tSrACC(4, 0, 1) = sAcc(tid * 8 + 4 + 16 * 32);
tSrACC(5, 0, 1) = sAcc(tid * 8 + 5 + 16 * 32);
tSrACC(6, 0, 1) = sAcc(tid * 8 + 6 + 16 * 32);
tSrACC(7, 0, 1) = sAcc(tid * 8 + 7 + 16 * 32);
return tSrACC;
}
#else
// Stages four values of tOrP per thread through sAcc and reads the result
// back as one packed intx4_t per thread.
//
//   tOrP  source tensor; elements (0..3, 0, 0) are read.
//   sAcc  staging buffer written and then re-read by the whole block.
//         NOTE(review): presumably shared memory -- confirm against the caller.
//   data  output; receives the intx4_t found at element offset (tid * 16) of sAcc.
//
// tiled_mma and tiled_mma_o are not used by this variant.
template <class TiledMma, class TiledMma_O,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1
>
__forceinline__ __device__ void convert_layout_acc_Aregs_fp8(const TiledMma& tiled_mma, const TiledMma_O& tiled_mma_o, Tensor<Engine0, Layout0> const& tOrP,
                                    Tensor<Engine1, Layout1> const& sAcc, intx4_t &data)
{
    using Value_type = typename Engine0::value_type;
    const int lane = threadIdx.x % 64;   // position inside a group of 64 threads
    const int wave = threadIdx.x / 64;   // index of that 64-thread group

    // Destination of this thread's first element; the remaining three follow
    // contiguously.
    const int store_base = (lane % 16) * 16
                         + ((lane / 16) % 2) * 4
                         + (lane / 32) * (16 * 16)
                         + (wave % 2) * (16 * 32)
                         + (wave / 2) * 8;
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        sAcc[store_base + i] = tOrP(i, 0, 0);
    }

    // Every thread's writes must land before the packed read below.
    __syncthreads();

    data = *reinterpret_cast<intx4_t*>(&(sAcc[lane * 16]));
}
#endif
#if 0
// Variant of __ds_read_m32x32_row_col_rrow that writes into a tensor. It
// issues one __builtin_amdgcn_ds_read_m32x32u8 at the offset that src's layout
// gives for coordinate (0, row, col), then copies the first 16 bytes of the
// result to the storage starting at dst(0, r_row, col).
template<int row, int col, int r_row, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, Tensor1& dst)
{
    auto src_layout = src.layout();
    constexpr short lds_offset = src_layout(0, row, col) * 1;
    auto lds_base = reinterpret_cast<int *>(src.data().get());
    auto loaded = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds_base), lds_offset);

    // Byte-wise copy of the loaded value into the destination fragment.
    const uint8_t * from = reinterpret_cast<const uint8_t*>(&loaded);
    uint8_t * to = reinterpret_cast<uint8_t*>(&(dst(0, r_row, col)));
    #pragma unroll
    for (int b = 0; b < 16; ++b) {
        to[b] = from[b];
    }
}
#else
// Reads one intx4_t from the memory behind src using
// __builtin_amdgcn_ds_read_m32x32u8.
//
//   row, col  compile-time coordinates; src's layout evaluated at (0, row, col)
//             gives the offset handed to the builtin.
//   r_row     not used by this variant.
//   src       source tensor; its data pointer is cast to address space 3.
//   dst       receives the value returned by the builtin.
template<int row, int col, int r_row, typename Tensor0>
__forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, intx4_t& dst)
{
    auto src_layout = src.layout();
    constexpr short lds_offset = src_layout(0, row, col) * 1;
    auto lds_base = reinterpret_cast<int *>(src.data().get());
    dst = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds_base), lds_offset);
}
#endif
/*
The stock exp2f has special handling for very small results: for an input x
below -126, it computes 2^(x + 64) * 2^{-64}.
For deep learning, however, values on the order of 2^-126 are not really
important, so we keep only the plain v_exp_f32 computation and skip that
special case.
*/
// Device-side declaration bound by its asm label to the LLVM intrinsic
// "llvm.exp2.f32".
extern __device__ __attribute__((const)) float __llvm_exp2_f32(float) __asm("llvm.exp2.f32");
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
...@@ -17,6 +17,7 @@ compute_attn_1rowblock_splitkv_mla_kvfp8(const DenseAttnDecodeParams_fp8 &params ...@@ -17,6 +17,7 @@ compute_attn_1rowblock_splitkv_mla_kvfp8(const DenseAttnDecodeParams_fp8 &params
const int n_split_idx, const int seqlen_k, const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit) const int n_block_min, const int n_block_max, const bool NoSplit)
{ {
#if 0
constexpr static bool Is_causal = T::Is_causal; constexpr static bool Is_causal = T::Is_causal;
constexpr int kBlockM = T::kBlockM; constexpr int kBlockM = T::kBlockM;
constexpr int kBlockN = T::kBlockN; constexpr int kBlockN = T::kBlockN;
...@@ -384,9 +385,7 @@ compute_attn_1rowblock_splitkv_mla_kvfp8(const DenseAttnDecodeParams_fp8 &params ...@@ -384,9 +385,7 @@ compute_attn_1rowblock_splitkv_mla_kvfp8(const DenseAttnDecodeParams_fp8 &params
} }
} }
} }
#endif
} }
template<typename T> template<typename T>
......
...@@ -184,9 +184,9 @@ inline constexpr bool is_decode_v = std::bool_constant<FWD_MODE == SparseAttnFwd ...@@ -184,9 +184,9 @@ inline constexpr bool is_decode_v = std::bool_constant<FWD_MODE == SparseAttnFwd
template<SparseAttnFwdMode FWD_MODE> template<SparseAttnFwdMode FWD_MODE>
using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>; using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>;
enum class Fp8KVCacheDataType { // enum class Fp8KVCacheDataType {
kAuto = 0, // kAuto = 0,
kFp8E4M3 = 1, // kFp8E4M3 = 1,
kFp8E5M2 = 2, // kFp8E5M2 = 2,
kInt8 = 3, // kInt8 = 3,
}; // };
...@@ -1081,7 +1081,7 @@ void wait_vmcnt() { ...@@ -1081,7 +1081,7 @@ void wait_vmcnt() {
"s_barrier; \n\t" "s_barrier; \n\t"
:: "n"(N)); :: "n"(N));
} }
#if 0
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4, template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy, typename ThrCopy> typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB_int8, Tensor3 &tCrB, Tensor4 const& tCsB, __forceinline__ __device__ void gemm_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB_int8, Tensor3 &tCrB, Tensor4 const& tCsB,
...@@ -1302,7 +1302,7 @@ __forceinline__ __device__ void gemm_k_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tens ...@@ -1302,7 +1302,7 @@ __forceinline__ __device__ void gemm_k_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tens
} }
cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc); cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
} }
#endif
template < template <
bool Is_even_MN=true, bool Is_even_MN=true,
bool Is_even_K=true, bool Is_even_K=true,
...@@ -1367,7 +1367,7 @@ buffer_load_copy_fp8( ...@@ -1367,7 +1367,7 @@ buffer_load_copy_fp8(
} }
} }
#if 0
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor3, typename Tensor4, template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy, typename ThrCopy> typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor3 &tCrB, Tensor4 const& tCsB, __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor3 &tCrB, Tensor4 const& tCsB,
...@@ -1500,8 +1500,7 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor ...@@ -1500,8 +1500,7 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
} }
} }
#endif
} }
\ No newline at end of file
...@@ -4,14 +4,27 @@ from flash_mla.flash_mla_interface import ( ...@@ -4,14 +4,27 @@ from flash_mla.flash_mla_interface import (
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache, flash_mla_with_kvcache,
flash_mla_sparse_fwd, flash_mla_sparse_fwd,
flash_mla_with_kvcache_qkvfp8, get_mla_decoding_metadata_dense_fp8,
flash_mla_with_kvcache_kvfp8 flash_mla_with_kvcache_quantization,
flash_mla_with_kvcache_q_nope_pe,
flash_mla_with_kvcache_quantization_q_nope_pe,
flash_mla_with_kvcache_fp8,
flash_mla_with_kvcache_fp8_with_cat
) )
__all__ = [ __all__ = [
"get_mla_metadata", "get_mla_metadata",
"flash_mla_with_kvcache", "flash_mla_with_kvcache",
"flash_mla_sparse_fwd", "flash_mla_sparse_fwd",
"flash_mla_with_kvcache_qkvfp8", "get_mla_decoding_metadata_dense_fp8",
"flash_mla_with_kvcache_kvfp8" "flash_mla_with_kvcache_quantization",
"flash_mla_with_kvcache_q_nope_pe",
"flash_mla_with_kvcache_quantization_q_nope_pe",
"flash_mla_with_kvcache_fp8",
"flash_mla_with_kvcache_fp8_with_cat"
] ]
import os

# Absolute path of the directory that contains this file.
FLASH_MLA_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# Export the "asm/" subdirectory (trailing slash included) through the
# environment. Note that the environment value differs from the Python
# constant above, which has no "/asm/" suffix.
# NOTE(review): presumably consumed by the native extension -- confirm.
os.environ["FLASH_MLA_ROOT_DIR"] = f"{FLASH_MLA_ROOT_DIR}/asm/"
\ No newline at end of file
# Assemble the gfx938 kernel assembly source (.s) into an object file (.o)
# for the amdgcn-amd-amdhsa target, using code object version 4.
clang -x assembler -target amdgcn-amd-amdhsa -mcode-object-version=4 -mcpu=gfx938:sramecc+ -c -o flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.o flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.s
# Link that object file into the .co output file.
clang -target amdgcn-amd-amdhsa flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.o -o flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment