Commit 8a69b46c authored by zhanghj2's avatar zhanghj2
Browse files

Merge branch 'feature/bmz-das-prefill-glm5-nheads64' into 'master'

优化 nmz和bmz dsa prefill,nhead=64

See merge request dcutoolkit/deeplearing/flashmla!6
parents a9ef79c6 14b2cfc5
......@@ -124,5 +124,35 @@ static void run(const SparseAttnFwdParams &params);
};
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048>
class KernelTemplate_B_H_64
{
public:
static constexpr int D_Q = D_QK;
static constexpr int D_K = D_QK;
static constexpr int D_V = 512;
static constexpr int kNWarps = 4;
static constexpr int B_H = 64;
static constexpr int B_TOPK = 64; // TopK block size
static constexpr int NUM_THREADS = kNWarps * 64;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits)
using Element = cutlass::bfloat16_t;
using elem_type = Element;
using ElementAccum = float;
using index_t = int64_t;
static constexpr int kBlockM = B_H;
static constexpr int kBlockN = B_TOPK;
static constexpr int kHeadDim = D_QK;
static constexpr int kHeadDimV = D_V;
static __device__ __forceinline__ void
devfunc(const SparseAttnFwdParams &params);
static void run(const SparseAttnFwdParams &params);
};
};
......@@ -10,6 +10,762 @@ namespace gfx93::fwd {
using namespace cute;
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048>
__device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::devfunc(const SparseAttnFwdParams &params) {
const int tidx = threadIdx.x;
static constexpr int kBlockM = B_H;
static constexpr int kBlockN = B_TOPK;
static constexpr int kHeadDim = D_QK;
static constexpr int kHeadDimV = D_V;
const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64);
const int s_q_idx = blockIdx.y;
const int bidh = blockIdx.x;
const int lane_idx = tidx % 64;
extern __shared__ Element smem[];
Element* q_lds = (Element*)&(smem);
Element* k_lds = q_lds;
Element* v_lds = q_lds;
int* sIndices = (int *)(q_lds + 8192);
const index_t row_offset_q = s_q_idx * static_cast<index_t>(params.stride_q_s_q) + bidh * kBlockM * params.stride_q_h_q;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.stride_q_h_q, _1{}));
const index_t row_offset_k = 0 * params.stride_kv_h_kv;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.kv) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.stride_kv_s_kv, _1{}));
const index_t row_offset_topk = s_q_idx * params.stride_indices_s_q;
int* gIndices = reinterpret_cast<int *>(params.indices) + row_offset_topk;
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
typedef __bf16 __fp16x2_t __attribute__((ext_vector_type(2)));
union Bf16_storage {
__fp16x8_t data_128;
__fp16x4_t data_64[2];
__fp16x2_t data_32[4];
uint16_t data_array[8];
};
union Bf16_storage_x4 {
__fp16x4_t data_64;
__fp16x2_t data_32[2];
uint16_t data[4];
};
const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk;
const int num_topk_blocks = IS_TOPK_2048? 2048 / B_TOPK : HAVE_TOPK_LENGTH ? ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK);
// TiledMMA tiled_mma = TiledMma{};
// auto thr_mma = tiled_mma.get_thread_slice(tidx);
flash::Softmax<1> softmax;
// #if 1
// #if defined(__gfx938__)
// #else
int virtual_row_ = lane_idx / 8;//0
int virtual_col_ = lane_idx % 8;//0
int swizzle_col_ = virtual_row_ ^ virtual_col_;
int row_ = lane_idx / 4;//0
// 8->9 9->8
// row_ = (row_ >= 8 ) ^ row_;
int col_ = swizzle_col_ % 4;
// #endif
auto calc_row_and_col_k = [&](const int block_idx) -> std::tuple<int, int> {
constexpr int elements_per_thread = 8;
// int row = lane_idx % 16;
// int col = lane_idx / 16;
// int row_offset = row * 4 + warp_idx + block_idx * kBlockN;
#if defined(__gfx938__)
// int row = lane_idx / 4;
// int col = lane_idx % 4;
// col = (col + (4 - (row / 2) % 4)) % 4;
// int row_offset = row + warp_idx * 16 + block_idx * kBlockN;
// int col_offset = col * 8;
int row_offset = row_ + warp_idx * 16 + block_idx * kBlockN;
int col_offset = col_ * 8;
#else
int row_offset = row_ * 4 + warp_idx + block_idx * kBlockN;
int col_offset = col_ * 8;
#endif
// int row_offset = row + warp_idx * 16 + block_idx * kBlockN;
if constexpr (IS_TOPK_2048) {
row_offset = sIndices[row_offset % 1024];
} else {
row_offset = gIndices[row_offset];
}
return {row_offset, col_offset};
};
auto calc_row_and_col_v = [&](const int block_idx, int i) -> int {
int row = lane_idx / 4;
// int col = lane_idx % 4;
int row_offset = row + i * 16 + block_idx * kBlockN;;
// int col_offset = col * 8 + warp_idx * 32;
if constexpr (IS_TOPK_2048) {
row_offset = sIndices[row_offset % 1024];
} else {
row_offset = gIndices[row_offset];
}
row_offset = row_offset == -1 ? params.s_kv : row_offset;
return row_offset;
};
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr_q;
*(uint64_t*)&glob_ptr_q = reinterpret_cast<uint64_t>(gQ.data().get());
glob_ptr_q.latter |= ((params.stride_q_h_q * 2) << 16);
glob_ptr_q.latter |= 0x40000000;
uint32x4_t global_addr_q = {0};
global_addr_q[0] = (glob_ptr_q.former);
global_addr_q[1] = (glob_ptr_q.latter);
global_addr_q[2] = 64;
global_addr_q[3] = 0x00020000;
PtrWrapper glob_ptr_indices;
*(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(gIndices);
// glob_ptr_indices.latter |= ((params.stride_indices_s_q * 4) << 16);
// *(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(params.indices);
// glob_ptr_indices.latter |= ((params.stride_indices_s_q * 4) << 16);
glob_ptr_indices.latter |= 0x40000000;
uint32x4_t global_addr_indices = {0};
global_addr_indices[0] = (glob_ptr_indices.former);
global_addr_indices[1] = (glob_ptr_indices.latter);
global_addr_indices[2] = 0x80000000;
global_addr_indices[3] = 0x00020000;
auto buffer_load_lds_indices = [&] (int n) {
constexpr int element_size = 4;
int ldsAddrPerWave = reinterpret_cast<size_t>(sIndices) + warp_idx * 64 * 4 * 4;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// uint32x2_t index_offset = {0};
// index_offset[0] = s_q_idx;
// index_offset[1] = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4;
const int offset_v = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4;
const int offset_s = n * 1024 * 4;
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr_indices), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
};
buffer_load_lds_indices(0);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
PtrWrapper glob_ptr_k;
*(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(gK.data().get());
glob_ptr_k.latter |= ((params.stride_kv_s_kv * 2) << 16);
glob_ptr_k.latter |= 0x40000000;
uint32x4_t global_addr_k = {0};
global_addr_k[0] = (glob_ptr_k.former);
global_addr_k[1] = (glob_ptr_k.latter);
global_addr_k[2] = params.s_kv;
global_addr_k[3] = 0x00020000;
auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx) {
constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
int col_offset = col;
int offset_v = col_offset * 2;
int ldsAddrPerWave = reinterpret_cast<size_t>(k_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 4) * 64 * 32 * 2;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t index_offset = {0};
index_offset[0] = row_offset;
index_offset[1] = offset_v;
const int offset_s = k_idx * 32 * 2;
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
};
auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx) {
constexpr int element_size = 2;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
int col_offset = col;
// int v_idx = row_offset;
int offset_v = col_offset * 2;
int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 1) * 512 * 16 * 2 + n_idx * 128 * 16 * 2;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t index_offset = {0};
index_offset[0] = row_offset;
index_offset[1] = offset_v;
const int offset_s = n_idx * 128 * 2;
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
};
const int v_lds_read_ptr = reinterpret_cast<size_t>(v_lds + lane_idx * 8);
auto k_lds_read_offset = [&] () -> int {
// #if defined(__gfx938__)
// int row = lane_idx % 16;
// int col = lane_idx / 16;
// col = (col + (row / 2) % 4) % 4;
// const auto lds_offset = row * 32 + col * 8;
// #else
int row = lane_idx % 16;
int col = lane_idx / 16;
col = (row / 2) ^ col;
col = col % 4;
// row = (row >= 8) ^ row;
const auto lds_offset = row * 32 + col * 8;
// #endif
return lds_offset;
};
Element* q_lds_read_ptr = (q_lds + warp_idx * 16 * 32 + lane_idx * 8);
Element* k_lds_read_ptr = (k_lds + k_lds_read_offset());
Bf16_storage q_reg[18];
for (int i = 0; i < 18; i++)
{
constexpr int elements_per_thread = 8;
int row = lane_idx % 16;
int col = lane_idx / 16;
int row_offset = row + warp_idx * 16;
int col_offset = col * 8;
int offset_v = col_offset * 2 + i * 32 * 2;
q_reg[i].data_128 = __builtin_amdgcn_buffer_load_dwordx4(global_addr_q, row_offset, offset_v, false, false);
}
__syncthreads();
v4f acco_f32[32];
for (int i = 0; i < 32; i++)
{
acco_f32[i].x = 0.0f;
acco_f32[i].y = 0.0f;
acco_f32[i].z = 0.0f;
acco_f32[i].w = 0.0f;
}
int col_offset_v = (lane_idx % 4) * 8 + warp_idx * 32;
struct IsFirstBlock {};
struct IsOtherBlock {};
auto float2bf16 = [] (float s) -> uint16_t {
uint32_t x32 = reinterpret_cast<uint32_t const &>(s);
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 1
x32 += 0x8000u;
#endif
return uint16_t(x32 >> 16);
};
auto process_one_block = [&] (int block_idx, auto is_block_t) {
static constexpr bool IS_FIRST_BLOCK = std::is_same_v<decltype(is_block_t), IsFirstBlock>;
static constexpr bool IS_OTHER_BLOCK = std::is_same_v<decltype(is_block_t), IsOtherBlock>;
v4f accs_f32[4];
for (int i = 0; i < 4; i++)
{
accs_f32[i].x = 0.0f;
accs_f32[i].y = 0.0f;
accs_f32[i].z = 0.0f;
accs_f32[i].w = 0.0f;
}
auto [row_offset, col] = calc_row_and_col_k(block_idx);
row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
{
constexpr int k_val = (17);
buffer_load_lds_k(row_offset, col, k_val);
buffer_load_lds_k(row_offset, col, k_val - 1);
buffer_load_lds_k(row_offset, col, k_val - 2);
buffer_load_lds_k(row_offset, col, k_val - 3);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
LOAD_K_AND_QK_GEMM(16);
LOAD_K_AND_QK_GEMM(15);
LOAD_K_AND_QK_GEMM(14);
LOAD_K_AND_QK_GEMM(13);
LOAD_K_AND_QK_GEMM(12);
LOAD_K_AND_QK_GEMM(11);
LOAD_K_AND_QK_GEMM(10);
LOAD_K_AND_QK_GEMM(9);
LOAD_K_AND_QK_GEMM(8);
LOAD_K_AND_QK_GEMM(7);
LOAD_K_AND_QK_GEMM(6);
LOAD_K_AND_QK_GEMM(5);
LOAD_K_AND_QK_GEMM(4);
LOAD_K_AND_QK_GEMM(3);
flash::qk_gemm<Element, k_val - 15>(q_reg[k_val - 15].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 16>(q_reg[k_val - 16].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 17>(q_reg[k_val - 17].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
}
#else
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val); \
buffer_load_lds_k(row_offset, col, k_val + 1); \
buffer_load_lds_k(row_offset, col, k_val + 2); \
buffer_load_lds_k(row_offset, col, k_val + 3); \
buffer_load_lds_k(row_offset, col, k_val + 4); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 1>(q_reg[k_val + 1].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 2>(q_reg[k_val + 2].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 3>(q_reg[k_val + 3].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 4>(q_reg[k_val + 4].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_barrier \n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
LOAD_K_AND_QK_GEMM(0);
LOAD_K_AND_QK_GEMM(5);
LOAD_K_AND_QK_GEMM(10);
{
constexpr int k_val = (15);
buffer_load_lds_k(row_offset, col, k_val);
buffer_load_lds_k(row_offset, col, k_val + 1);
buffer_load_lds_k(row_offset, col, k_val + 2);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val + 1>(q_reg[k_val + 1].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val + 2>(q_reg[k_val + 2].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier \n\t");
__builtin_amdgcn_sched_barrier(0);
}
#endif
auto is_valid_token = [&](const int idx) -> bool {
const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16;
int offs = n_idx + block_idx * kBlockN;
int t;
if constexpr (IS_TOPK_2048) {
t = sIndices[offs % 1024];
} else {
t = gIndices[offs];
}
bool is_cur_token_valid = t >= 0 && t < params.s_kv;
if constexpr (HAVE_TOPK_LENGTH) {
is_cur_token_valid = is_cur_token_valid && (offs < topk_length);
}
return is_cur_token_valid;
};
for (int i = 0; i < 16; ++i) {
#if defined(__gfx938__)
if (!is_valid_token(i)) accs_f32[i/4][i%4] = -INFINITY;
#else
if (!is_valid_token(i)) accs_f32[i%4][i/4] = -INFINITY;
#endif
}
// Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
// Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
Tensor scores = make_tensor<float>(Shape<_1, _16>{});
for (int i = 0; i < 16; i++) {
#if defined(__gfx938__)
scores(0, i) = accs_f32[i/4][i%4];
#else
scores(0, i) = accs_f32[i%4][i/4];
#endif
}
softmax.template softmax_rescale_o_prefill_4x1</*Is_first=*/IS_FIRST_BLOCK, /*Check_inf=*//*Is_local=*/false>(scores, acco_f32, params.sm_scale_div_log2);
Bf16_storage_x4 p[4];
for (int i = 0; i < 4; i++)
{
#if defined(__gfx938__)
p[i].data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4), 0, scores(0, i * 4 + 1), 0);
p[i].data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4 + 2), 0, scores(0, i * 4 + 3), 0);
#else
p[i].data[0] = float2bf16(scores(0, i * 4));
p[i].data[1] = float2bf16(scores(0, i * 4 + 1));
p[i].data[2] = float2bf16(scores(0, i * 4 + 2));
p[i].data[3] = float2bf16(scores(0, i * 4 + 3));
#endif
}
int row_offset_v[4];
for (int i = 0; i < 4; i++)
{
row_offset_v[i] = calc_row_and_col_v(block_idx, i);
}
__syncthreads();
#if 1
{
constexpr int k_val = (0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 2);
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 3);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0);
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1);
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 2);
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 3);
}
#define LOAD_V_AND_PV_GEMM(k) \
{ \
constexpr int k_val = (k); \
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0); \
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1); \
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 2); \
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 3); \
}
LOAD_V_AND_PV_GEMM(1);
LOAD_V_AND_PV_GEMM(2);
{
constexpr int k_val = (3);
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
}
#else
#define LOAD_V_AND_PV_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 2); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 3); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_barrier \n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
LOAD_V_AND_PV_GEMM(0);
LOAD_V_AND_PV_GEMM(1);
LOAD_V_AND_PV_GEMM(2);
LOAD_V_AND_PV_GEMM(3);
#endif
};
if constexpr (IS_TOPK_2048)
{
process_one_block(0, IsFirstBlock{});
for (int block_idx = 1; block_idx < 1024 / B_TOPK; block_idx ++)
{
process_one_block(block_idx, IsOtherBlock{});
}
buffer_load_lds_indices(1);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
for (int block_idx = 1024/B_TOPK; block_idx < 2048 / B_TOPK; block_idx ++)
{
process_one_block(block_idx, IsOtherBlock{});
}
}
else
{
process_one_block(0, IsFirstBlock{});
for (int block_idx = 1; block_idx < num_topk_blocks; block_idx ++)
{
process_one_block(block_idx, IsOtherBlock{});
}
}
Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1<false>(acco_f32, params.sm_scale);
// if (block0())
// {
// printf(" threadIdx.x %d %.3f %.3f %.3f %.3f \n", threadIdx.x,
// acco_f32[0].x,
// acco_f32[0].y,
// acco_f32[0].z,
// acco_f32[0].w
// );
// }
const index_t row_offset_o = s_q_idx * static_cast<index_t>(params.h_q * params.d_v) + bidh * kBlockM * params.d_v;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.out) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(params.d_v, _1{}));
const index_t row_offset_lse = s_q_idx * params.h_q + bidh * kBlockM;
float* gLSE = reinterpret_cast<float *>(params.lse) + row_offset_lse;
// const index_t row_offset_lse = m_block * params.h_q;
float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse;
{
// store O and gLSE
// auto rO = flash::convert_type<Element>(acc_o);
int row, col;
// const int warpId = tidx / 64;
// const int laneId = tidx % 64;
for (int mi = 0; mi < 1; ++mi) {
row = mi * kBlockM + lane_idx % 16 + warp_idx * 16;
// if (row < params.h_q)
{
for (int ni = 0; ni < 16; ++ni) {
#if defined(__gfx938__)
Bf16_storage res;
col = (lane_idx / 16) * 8 + ni * 32 ;
res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0], 0, acco_f32[ni * 2 + 1][0], 0);
res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1], 0, acco_f32[ni * 2 + 1][1], 0);
res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2], 0, acco_f32[ni * 2 + 1][2], 0);
res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3], 0, acco_f32[ni * 2 + 1][3], 0);
*(__fp16x8_t*)(&gO(row, col)) = res.data_128;
#else
col = (lane_idx / 16) * 2 + ni * 32 ;
using result_type = cutlass::Array<Element, 2>;
for (int ei = 0; ei < 4; ei++)
{
result_type res;
Element e0, e1;
e0.storage = float2bf16(acco_f32[ni * 2][ei]);
e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei]);
res[0] = e0;
res[1] = e1;
// gO(row, col) = res[0];
// gO(row, col + 1) = res[1];
*(result_type*)(&gO(row, col)) = res;
col += 8;
}
#endif
}
gLSE[row] = lse(mi);
if constexpr (HAVE_TOPK_LENGTH)
{
gMax_logits[row] = topk_length == 0 ? -INFINITY : softmax.row_max(mi) * params.sm_scale;
}
else
{
gMax_logits[row] = softmax.row_max(mi) * params.sm_scale;
}
}
}
}
}
template<int D_QK, bool HAVE_TOPK_LENGTH>
__device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttnFwdParams &params) {
extern __shared__ char smem_[];
......@@ -529,7 +1285,9 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
template<typename Kernel>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1)
sparse_attn_fwd_kernel(const SparseAttnFwdParams params) {
// #if defined(__gfx936__)
Kernel::devfunc(params);
// #endif
}
template<int D_QK, bool HAVE_TOPK_LENGTH>
......@@ -545,9 +1303,36 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_CHECK_KERNEL_LAUNCH();
}
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048>
void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::run(const SparseAttnFwdParams &params) {
KU_ASSERT(params.h_kv == 1);
// KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings
KU_ASSERT(params.topk > 0);
// KU_ASSERT(params.h_q % B_H == 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>>;
constexpr size_t smem_size = 16384 + 4096; // 做了lds复用
dim3 grid((params.h_q + B_H - 1) / B_H, params.s_q, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH();
}
template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(params);
if (params.h_q == 64 && !HAVE_TOPK_LENGTH && D_QK == 576 && !params.attn_sink)
{
if (params.topk == 2048)
{
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, true>::run(params);
}
else
{
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, false>::run(params);
}
}
else
{
KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(params);
}
}
}
......@@ -602,6 +602,95 @@ struct Softmax {
}
return lse;
};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ void softmax_rescale_o_prefill_4x1(Tensor0& scores, v4f* acc_o, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
// Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if constexpr(Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !true
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
// if (blockIdx.x == 0 && threadIdx.x == 0)
// {
// printf("threadIdx.x %.2f, scores_scale = %.4f\n",row_sum(mi), scores_scale );
// }
row_sum(mi) *= scores_scale;
for (int i = 0; i < 32; i++)
{
acc_o[i].x *= scores_scale;
acc_o[i].y *= scores_scale;
acc_o[i].z *= scores_scale;
acc_o[i].w *= scores_scale;
}
}
// if (blockIdx.x == 2)
// {
// printf("threadIdx.x %.2f \n",row_sum(mi) );
// }
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
// if (thread0())
// {
// printf("max sum %.3f %.3f \n", row_max(0), row_sum(0));
// }
};
template<bool Is_dropout=false, bool Split=false>
__forceinline__ __device__ TensorT normalize_softmax_lse_prefill_4x1(v4f *acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
// flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
TensorT lse = make_fragment_like(row_sum);
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
// if (thread0())
// {
// printf(" %.3f %.3f \n", row_max(0), row_sum(0));
// }
#pragma unroll
for (int mi = 0; mi < 1; ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
for (int i = 0; i < 32; i++)
{
acc_o[i].x *= scale;
acc_o[i].y *= scale;
acc_o[i].z *= scale;
acc_o[i].w *= scale;
}
}
return lse;
};
};
......
......@@ -1523,6 +1523,91 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
}
#endif
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
template<typename Element, int k_idx>
__forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds_read_ptr, v4f* accs_f32)
{
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
union Bf16_storage {
__fp16x8_t data_128;
__fp16x4_t data_64[2];
uint16_t data_array[8];
};
constexpr int k_idx_even = k_idx % 4;
constexpr int n_offset = 16 * 32;
constexpr int k_offset = k_idx_even * 64 * 32;
Bf16_storage q_reg;
Bf16_storage k_reg;
q_reg.data_128 = q_data;
k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset);
// q_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(q_lds_read_ptr), k_offset, 2, 1, 0);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 0 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[0], true,false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[0], true,false);
#else
accs_f32[0] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[0]);
accs_f32[0] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[0]);
#endif
k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 1 * n_offset);
#if defined(__gfx938__)
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[1], true,false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[1], true,false);
#else
accs_f32[1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[1]);
accs_f32[1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[1]);
#endif
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 1 * n_offset + k_offset, 2, 1, 0);
k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 2 * n_offset);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 2 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32[2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[2], true,false);
accs_f32[2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[2], true,false);
#else
accs_f32[2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[2]);
accs_f32[2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[2]);
#endif
k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 3 * n_offset);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 3 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32[3] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[3], true,false);
accs_f32[3] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[3], true,false);
#else
accs_f32[3] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[3]);
accs_f32[3] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[3]);
#endif
}
typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
template<int k_idx, int n_idx_val>
__forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr, v4f* acco_f32)
{
constexpr int k_idx_even = k_idx % 1;
constexpr int n_offset = 16 * 32 * 2;
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
union Bf16_storage {
__fp16x8_t data_128;
__fp16x4_t data_64[2];
uint16_t data_array[8];
};
constexpr int k_offset = k_idx_even * 16 * 512 * 2;
// #if 1
Bf16_storage v_reg;
v_reg.data_128 = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(v_lds_read_ptr), k_offset + n_idx_val * n_offset);
#if defined(__gfx938__)
acco_f32[n_idx_val * 2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[0], acco_f32[n_idx_val * 2], true, false);
acco_f32[n_idx_val * 2 + 1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[1], acco_f32[n_idx_val * 2 + 1], true, false);
#else
acco_f32[n_idx_val * 2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(p, v_reg.data_64[0], acco_f32[n_idx_val * 2]);
acco_f32[n_idx_val * 2 + 1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(p, v_reg.data_64[1], acco_f32[n_idx_val * 2 + 1]);
#endif
}
}
\ No newline at end of file
......@@ -77,7 +77,7 @@ if __name__ == '__main__':
(1840, 256),
(1592, 384),
(1521, 512),
(3000, 2048),
# Irregular shapes with OOB TopK
(95, 128),
(153, 256),
......@@ -146,6 +146,7 @@ if __name__ == '__main__':
performance_case_templates = [
# V3.2
(576, 128, 2048, [8192, 32768, 65536, 98304, 131072]),
(576, 64, 2048, [8192, 32768, 65536, 98304, 131072]),
# MODEL1 CONFIG1
(512, 64, 512, [8192, 32768, 49152, 65536]),
# MODEL1 CONFIG2
......@@ -154,9 +155,10 @@ if __name__ == '__main__':
]
performance_cases = [
TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True)
TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=have_attn_sink)
for (d_qk, h_q, topk, s_kv_list) in performance_case_templates
for s_q in [4096]
for have_attn_sink in [False, True]
for s_kv in s_kv_list
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment