Commit a1eef562 authored by shenzhe's avatar shenzhe Committed by zhanghj2
Browse files

Add DSA MLS sparse prefill dispatch

parent 4e0bdf6e
#pragma once #pragma once
#include <cstdlib>
#include "common.h" #include "common.h"
#include "params.h" #include "params.h"
#include "gfx93/prefill/sparse/dsa_mls/fwd.h"
#include "gfx93/prefill/sparse/phase1.h" #include "gfx93/prefill/sparse/phase1.h"
...@@ -39,6 +42,12 @@ class Fwd_Sm90_Impl : public FwdImplBase { ...@@ -39,6 +42,12 @@ class Fwd_Sm90_Impl : public FwdImplBase {
protected: protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
if ((std::getenv("FLASH_MLA_FORCE_DSA_MLS_PREFILL") != nullptr && gfx93::fwd::dsa_mls::can_run(params)) ||
gfx93::fwd::dsa_mls::should_run(params)) {
gfx93::fwd::dsa_mls::run(params);
return;
}
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() { DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
gfx93::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params); gfx93::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
......
#pragma once
#include <algorithm>
#include <hip/hip_runtime.h>
#include "legacy/include/flash.h"
#include "legacy/include/kernel_traits.h"
#include "legacy/include/static_switch.h"
#include "legacy/src/flash_fwd_b16_mla.h"
namespace gfx93::fwd::dsa_mls {
template<typename T, int Headdim, int HeaddimV>
void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) {
constexpr int kBlockM = 64;
constexpr int kBlockN = 64;
constexpr int WARP_M = 16;
dim3 dimBlock;
dimBlock.x = std::min((kBlockM / WARP_M) * 64, 1024);
dimBlock.y = 1;
dimBlock.z = 1;
dim3 dimGrid;
dimGrid.x = (params.seqlen_q + kBlockM - 1) / kBlockM;
dimGrid.y = 1;
dimGrid.z = params.b;
using Kernel_traits = Flash_fwd_kernel_traits<
Headdim, HeaddimV, kBlockM, kBlockN, 32, WARP_M, 64, 2,
false, false, T, T>;
constexpr bool Is_dropout = false;
constexpr bool IsEvenMNConst = false;
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.topk == 2048) {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
});
});
}
} // namespace gfx93::fwd::dsa_mls
#include "fwd.h"
#include <cstring>
#include <stdexcept>
#include <string>
#include "dispatch.h"
namespace gfx93::fwd::dsa_mls {
bool can_run(const SparseAttnFwdParams& params) {
if (params.d_v != 512) return false;
if (params.d_qk != 512 && params.d_qk != 576) return false;
if (params.h_kv != 1) return false;
if (params.h_q != 64 && params.h_q != 128) return false;
if (!(params.topk <= 1024 || params.topk == 2048)) return false;
if (params.topk == 2048 && (params.attn_sink != nullptr || params.topk_length != nullptr)) return false;
return true;
}
bool should_run(const SparseAttnFwdParams& params) {
if (!can_run(params)) return false;
if (params.d_qk == 512 &&
((params.h_q == 64 && params.topk == 512) ||
(params.h_q == 128 && params.topk == 1024))) {
return true;
}
if (params.d_qk == 576 && params.h_q == 64 && params.topk == 2048 && params.s_kv >= 32768) {
return true;
}
return false;
}
static Flash_fwd_mla_params_dsa make_legacy_params(const SparseAttnFwdParams& src) {
if (!can_run(src)) {
throw std::runtime_error(
"DSA MLS sparse prefill only supports d_qk=512/576, d_v=512, "
"h_kv=1, h_q=64/128, topk<=1024 or topk=2048 without attn_sink/topk_length");
}
Flash_fwd_mla_params_dsa dst;
std::memset(&dst, 0, sizeof(dst));
dst.layout = 1;
dst.b = 1;
dst.h = 1;
dst.h_k = src.h_kv;
dst.h_h_k_ratio = dst.h / dst.h_k;
dst.mtp = 1;
dst.ngroups = src.h_q / src.h_kv;
dst.topk = src.topk;
dst.d = src.d_qk;
dst.d_v = src.d_v;
dst.scale_softmax = src.sm_scale;
dst.scale_softmax_log2 = src.sm_scale_div_log2;
dst.cu_seqlens_q = nullptr;
dst.cu_seqlens_k = nullptr;
dst.cu_seqlens_k_new = nullptr;
dst.topk_length = src.topk_length;
dst.attn_sink = src.attn_sink;
dst.q_ptr = src.q;
dst.k_ptr = src.kv;
dst.v_ptr = src.kv;
dst.o_ptr = src.out;
dst.sparse_indices = src.indices;
dst.softmax_lse_ptr = src.lse;
dst.scores_max_ptr = src.max_logits;
dst.scores_sum_ptr = nullptr;
dst.block_table = nullptr;
dst.block_table_batch_stride = 0;
dst.page_block_size = 0;
dst.is_causal = false;
dst.q_batch_stride = 0;
dst.q_token_stride = src.stride_q_s_q;
dst.q_head_stride = src.stride_q_h_q;
dst.q_row_stride = dst.q_head_stride;
dst.k_batch_stride = 0;
dst.k_row_stride = src.stride_kv_s_kv;
dst.k_head_stride = src.stride_kv_h_kv;
dst.v_batch_stride = 0;
dst.v_row_stride = src.stride_kv_s_kv;
dst.v_head_stride = src.stride_kv_h_kv;
dst.o_batch_stride = 0;
dst.o_row_stride = src.h_q * src.d_v;
dst.o_head_stride = src.d_v;
dst.sparse_indices_batch_stride = 0;
dst.sparse_indices_row_stride = src.stride_indices_s_q;
dst.sparse_indices_head_stride = src.stride_indices_h_kv;
dst.sparse_indices_topk_stride = 1;
dst.seqlen_q = src.s_q * dst.ngroups;
dst.seqlen_k = src.s_kv;
dst.max_seqlen = src.s_q;
dst.is_bf16 = true;
dst.is_e4m3 = false;
dst.is_int8 = false;
dst.cu_count = src.num_sm;
dst.seqlenq_ngroups_swapped = true;
dst.is_seqlens_k_cumulative = false;
dst.splitkv_use_fp32_as_accum = false;
dst.num_splits = 0;
dst.partition_size = src.topk;
return dst;
}
void run(const SparseAttnFwdParams& params) {
Flash_fwd_mla_params_dsa legacy_params = make_legacy_params(params);
hipStream_t stream = reinterpret_cast<hipStream_t>(params.stream);
if (params.d_qk == 512) {
run_dsa_prefill_nopage_64_dispatch<BFloat16, 512, 512>(legacy_params, stream);
} else if (params.d_qk == 576) {
run_dsa_prefill_nopage_64_dispatch<BFloat16, 576, 512>(legacy_params, stream);
} else {
throw std::runtime_error("Unsupported d_qk value in DSA MLS sparse prefill");
}
}
} // namespace gfx93::fwd::dsa_mls
#pragma once
#include "../../../../params.h"
namespace gfx93::fwd::dsa_mls {
bool can_run(const SparseAttnFwdParams& params);
bool should_run(const SparseAttnFwdParams& params);
void run(const SparseAttnFwdParams& params);
} // namespace gfx93::fwd::dsa_mls
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true, bool Is_Kvcache=false, bool USE_BSHD_LAYOUT = false>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q((!Varlen || params.cu_seqlens_q == nullptr) ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k((!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative) ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr || Is_Kvcache ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache/* + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)*/)
, nheads(params.h)
, nheads_k(params.h_k)
, leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
{
}
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
}
inline __device__ int q_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_q) * row_stride : uint32_t(sum_s_q) * row_stride * nheads);
}
inline __device__ int k_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads_k);
}
inline __device__ int k_offset1_write(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads);
}
inline __device__ int q_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_q == -1) ? bidh * head_stride : uint32_t(actual_seqlen_q) * head_stride * bidh;
}
inline __device__ int k_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_k == -1) ? bidh * head_stride : uint32_t(actual_seqlen_k) * head_stride *bidh;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int leftpad_k;
const int seqlen_k_cache;
int actual_seqlen_k;
const int nheads;
const int nheads_k;
};
// Simplified blockinfo for tranditional varlen fwd inference
template<bool USE_BSHD_LAYOUT=false>
struct SimplifyBlockInfo {
template<typename Params>
__device__ SimplifyBlockInfo(const Params &params, const int bidb)
: sum_s_q(params.cu_seqlens_q[bidb])
, sum_s_k(params.cu_seqlens_k[bidb])
, actual_seqlen_q(params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache((params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache/* + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)*/)
, nheads(params.h)
, nheads_k(params.h_k)
// , leftpad_k(0)
{
}
inline __device__ int q_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_q) * row_stride : uint32_t(sum_s_q) * row_stride * nheads);
}
inline __device__ int k_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads_k);
}
inline __device__ int q_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_q == -1) ? bidh * head_stride : uint32_t(actual_seqlen_q) * head_stride * bidh;
}
inline __device__ int k_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_k == -1) ? bidh * head_stride : uint32_t(actual_seqlen_k) * head_stride *bidh;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
// const int leftpad_k;
const int seqlen_k_cache;
int actual_seqlen_k;
const int nheads;
const int nheads_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SafeDecodeBlockInfo {
__device__ SafeDecodeBlockInfo() = default;
template<typename Params, bool Is_Q_varlen, bool Is_K_Cumulative>
__device__ void set_params(const Params &params, const int bidb) {
// process Q
if constexpr (Is_Q_varlen) { // Is_Q_varlen also means Is_Q_Cumulative = true
this->sum_s_q = params.cu_seqlens_q[bidb];
this->actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - this->sum_s_q;
} else {
this->actual_seqlen_q = params.seqlen_q;
}
// process KV
if constexpr (Is_K_Cumulative) {
this->sum_s_k = params.cu_seqlens_k[bidb];
this->actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
} else {
this->actual_seqlen_k = params.cu_seqlens_k[bidb];
}
}
int sum_s_q;
int sum_s_k;
int actual_seqlen_q;
int actual_seqlen_k;
};
} // namespace flash
#pragma once
#include <block_info.h>
#include "utils.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccumType, int kBlockM_, int kBlockN_, int WARP_M_, int WARP_N_, int kHeadDim_, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccumType* dsoftmax_sum = static_cast<ElementAccumType*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id =0;
__shared__ Element do_lds[STAGES*(kBlockM_/32) * (kBlockN_/32)*(32*34)];
__shared__ Element o_lds[STAGES*(kBlockM_/32) * (kBlockN_/32)*(32*34)];
float dP_sum_cur[(kBlockM_/16)] = {0.0f};
int stage_id = 0;
constexpr int kBlockM = kBlockM_;
constexpr int kBlockN = kBlockN_;
constexpr int kHeadDim = kHeadDim_;
const int WARP_NUM = (kBlockM_)/(WARP_M_);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
// Element *gdO = reinterpret_cast<Element *>(do_ptr) + row_offset_do;
auto gdO = tcp_cache_swizzle_func<kHeadDim_, Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do);
// Element *gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
auto gO = tcp_cache_swizzle_func<kHeadDim_, Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o);
ElementAccumType *dP_sum = reinterpret_cast<ElementAccumType *>(dsoftmax_sum) + row_offset_dpsum;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_id)
: "v"(warp_id_vec)
:);
vec2_Element<Element> do_reg[(kHeadDim_/kBlockN_)*((WARP_M_*kBlockN_)/(32*32))*2][4]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
vec2_Element<Element> o_reg[(kHeadDim_/kBlockN_)*((WARP_M_*kBlockN_)/(32*32))*2][4]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
// int A_lane_m_idx = (lane_id >> 4);
int do_lane_m_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
int do_lane_head_dim_idx = (lane_id & 15);
//global->lds, left matrix
// printf("kBlockN_==%d, kHeadDim_=%d, WARP_M_=%d\n",kBlockN_, kHeadDim_, WARP_M_);
for(int k_loop=0; k_loop<kHeadDim_/kBlockN_; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN_;
// do_ptr buffer load mini size is 4*32, (kBlockM_ * kBlockN_) mini size is (32*32)
const int do_lds_load_num = (kBlockM_ * kBlockN_) / (4*32);
int do_lds_stage_offset = stage_id * (kBlockM_/32) * (kBlockN_/32)*(32*34);
for(int warp_loop=warp_id; warp_loop < do_lds_load_num; warp_loop+=WARP_NUM) {
int padding = (warp_loop & 7)*2; // padding size in shared memory per buffer load, to avoid bank conflict
int do_warp_buffer_load_m_id = (warp_loop & (kBlockM_/4 - 1)); //这样子对L1和utlc1有啥影响呢?
int do_warp_buffer_load_k_id = (warp_loop / (kBlockM_/4));
int do_warp_buffer_load_lds_offset = do_lds_stage_offset + (do_warp_buffer_load_k_id * kBlockM_ * 34) + ((do_warp_buffer_load_m_id >> 3)*(32*34) + (do_warp_buffer_load_m_id & 7)*(4*32)) ;
int do_warp_buffer_load_global_offset = (do_warp_buffer_load_k_id * 32);
int gsOffset = (do_block_buffer_load_global_offset + do_warp_buffer_load_global_offset)/2 ;
// int gvOffset = (do_lane_m_idx * kHeadDim_)/2 + do_lane_head_dim_idx;
int lds_offset = (do_warp_buffer_load_lds_offset + padding)/2;
{
int gvOffset;
if constexpr (!Is_even_MN) {
gvOffset = (min((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)),binfo.actual_seqlen_q - m_block * kBlockM - 1) * seqlen_do_stride)/2 + do_lane_head_dim_idx;
} else {
gvOffset = ((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)) * seqlen_do_stride)/2 + do_lane_head_dim_idx;
}
builtin_buffer_load_dword_lds(do_lds, gdO, lds_offset, gsOffset, gvOffset);
}
{
int gvOffset;
if constexpr (!Is_even_MN) {
gvOffset = (min((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)),binfo.actual_seqlen_q - m_block * kBlockM - 1) * seqlen_o_stride)/2 + do_lane_head_dim_idx;
} else {
gvOffset = ((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)) * seqlen_o_stride)/2 + do_lane_head_dim_idx;
}
builtin_buffer_load_dword_lds(o_lds, gO, lds_offset, gsOffset, gvOffset);
}
}
vmcnt_wait(0);
// By right we need to scale dP up by 1/params.p_dropout, but instead we don't and only scale the final
// results (dQ and dK) by 1/params.p_dropout. So we need to keep dP_sum scaled down by params.p_dropout here,
// so that (dP - dP_sum) is on the same scale.
{
//lds -> vgpr use ds_read_m; left matrix
int do_warp_m_id = (warp_id & ((kBlockM_/WARP_M_) - 1));
int do_lds_stage_offset = stage_id * (kBlockM_/32) * (kBlockN_/32)*(32*17);
vec2_Element<Element> *do_lds_v2fp16 = (vec2_Element<Element> *)(do_lds);
vec2_Element<Element> *o_lds_v2fp16 = (vec2_Element<Element> *)(o_lds);
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockN_/32); head_dim_idx++) { //32 half in col direction
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
//a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { //sequence direction
#pragma unroll
for(int vec_id=0; vec_id<4; vec_id++) { //head dim direction
int lds_offset = do_lds_stage_offset + head_dim_idx*kBlockM_*17 + (warp_id*(WARP_M_/32) + m_idx)*(32*17) + vec_id*2 + min_tile_m*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read_b32_wait(do_lds_v2fp16, lds_offset, do_reg[/*(k_loop)*((WARP_M_*kBlockN_)/(32*32))*2 +*/ (head_dim_idx*(WARP_M_/32) + m_idx)*2 + min_tile_m][vec_id]);
}
}
}
}
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockN_/32); head_dim_idx++) { //32 half in col direction
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
//a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { //sequence direction
#pragma unroll
for(int vec_id=0; vec_id<4; vec_id++) { //head dim direction
int lds_offset = do_lds_stage_offset + head_dim_idx*kBlockM_*17 + (warp_id*(WARP_M_/32) + m_idx)*(32*17) + vec_id*2 + min_tile_m*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read_b32_wait(o_lds_v2fp16, lds_offset, o_reg[/*(k_loop)*((WARP_M_*kBlockN_)/(32*32))*2 +*/ (head_dim_idx*(WARP_M_/32) + m_idx)*2 + min_tile_m][vec_id]);
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + mi*32 + min_tile_m + (threadIdx.x & 15)*2) < binfo.actual_seqlen_q) {
dP_sum_cur[mi*2 + min_tile_m] += UpCast<Element,float,true>(do_reg[(head_dim_idx*(WARP_M_/32) + mi)*2 + min_tile_m][vec_id][min_tile_n]) * UpCast<Element,float,true>(o_reg[(head_dim_idx*(WARP_M_/32) + mi)*2 + min_tile_m][vec_id][min_tile_n]);
}
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m + (threadIdx.x & 15)*2] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
\ No newline at end of file
#pragma once
#include <block_info.h>
#include "utils.h"
#include "prefetch.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccum, int kBlockM, int kBlockN, int WARP_M, int WARP_N, int K, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o_gfx938(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = 0;
__shared__ Element dO_lds[kBlockM * kBlockN];
__shared__ Element O_lds[kBlockM * kBlockN];
float dP_sum_cur[(kBlockM/16)] = {0.0f};
const int WARP_NUM = (kBlockM)/(WARP_M);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do, seqlen_do_stride);
auto gO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o, seqlen_o_stride);
ElementAccum *dP_sum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> dO_reg[((WARP_M*kBlockN)/(32*32))*2];
union_vec4_f16x2<Element> O_reg[((WARP_M*kBlockN)/(32*32))*2];
for(int k_loop=0; k_loop<K/kBlockN; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN;
//read 32 * 128
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gdO, do_block_buffer_load_global_offset, dO_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gO, do_block_buffer_load_global_offset, O_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for(int i = 0; i < kBlockN / 32; ++i) {
DS_READ_MATRIX_32X32_B16(ds_offset_cast(dO_lds + i * 32 * 32), dO_reg[i * 2 + 0].f16, dO_reg[i * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(ds_offset_cast(O_lds + i * 32 * 32), O_reg[i * 2 + 0].f16, O_reg[i * 2 + 1].f16, true);
// if constexpr (std::is_same_v<Element, half_t>) {
// dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
// dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
// O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 0, 2, 1, 0);
// O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
// } else {
// dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
// dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
// O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 0, 2, 1, 0);
// O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
// }
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + min_tile_m*16 + (threadIdx.x & 15)) < binfo.actual_seqlen_q) {
dP_sum_cur[min_tile_m] += UpCast<Element,float,false>(dO_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]) * UpCast<Element,float,false>(O_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]);
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m * 16 + (threadIdx.x & 15)] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
#pragma once
#include <block_info.h>
#include "utils.h"
#include "prefetch.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccum, int kBlockM, int kBlockN, int WARP_M, int WARP_N, int K, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o_gfx946(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = 0;
__shared__ Element dO_lds[kBlockM * kBlockN];
__shared__ Element O_lds[kBlockM * kBlockN];
float dP_sum_cur[(kBlockM/16)] = {0.0f};
const int WARP_NUM = (kBlockM)/(WARP_M);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do, seqlen_do_stride);
auto gO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o, seqlen_o_stride);
ElementAccum *dP_sum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_id)
: "v"(warp_id_vec)
:);
union_vec4_f16x2<Element> dO_reg[((WARP_M*kBlockN)/(32*32))*2];
union_vec4_f16x2<Element> O_reg[((WARP_M*kBlockN)/(32*32))*2];
for(int k_loop=0; k_loop<K/kBlockN; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN;
//read 32 * 128
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gdO, do_block_buffer_load_global_offset, dO_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gO, do_block_buffer_load_global_offset, O_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for(int i = 0; i < kBlockN / 32; ++i) {
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(dO_lds + i * 32 * 32), dO_reg[i * 2 + 0].f16, dO_reg[i * 2 + 1].f16, true);
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(O_lds + i * 32 * 32), O_reg[i * 2 + 0].f16, O_reg[i * 2 + 1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 0, 2, 1, 0);
O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
} else {
dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 0, 2, 1, 0);
O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + min_tile_m*16 + (threadIdx.x & 15)) < binfo.actual_seqlen_q) {
dP_sum_cur[min_tile_m] += UpCast<Element,float,false>(dO_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]) * UpCast<Element,float,false>(O_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]);
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m * 16 + (threadIdx.x & 15)] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
\ No newline at end of file
#include <iostream>
#include <memory>
#include <vector>
#include <random>
#include <fstream>
#include <stdlib.h>
#include <dirent.h>
#include <unistd.h>
#include <sys/stat.h>
#include "assert.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "flash.h"
#include "utils.h"
#include "wait.h"
#include "../numeric_types.h"
#include "philox.cuh"
#include "softmax_tiling.h"
#include "gpu_gemm_nn.h"
#include "gpu_gemm_tt.h"
#include "intrinsic.h"
#include "intrinsic_mls_ds.h"
#include "static_switch.h"
#include "dot_do_o.h"
#include "dot_do_o_gfx938.h"
#include "dot_do_o_gfx946.h"
#include "prefetch.h"
#include "flash_singleton.h"
#include "flash_attention_dv_dk_bwd.h"
#include "flash_attention_dv_dk_bwd_gfx938.h"
#include "flash_attention_dv_dk_bwd_gfx946.h"
#include "flash_attention_dq_bwd.h"
#include "flash_attention_dq_bwd_gfx938.h"
#include "flash_attention_dq_bwd_gfx946.h"
using std::make_shared;
using std::shared_ptr;
template <int kBlockM_, int kBlockN_, int WARP_M_, int WARP_N_, typename Element>
inline __device__ void reshape(Element* smem, vec4_Element<Element> ds_reg_fp16[(WARP_N_/32)*(WARP_M_/32)][4], int warp_id) {
int lane_id = threadIdx.x & 63; //lane id, 0-63
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
int lds_offset = warp_id*(WARP_N_/32)*33*kBlockM_ + n_idx*33*kBlockM_ + m_idx*32*33 + min_tile_m*16*33 + vec_idx*4*33 + (lane_id>>4)*33 + min_tile_n*16 + (lane_id&15);
Element ds_reg_tmp = ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx];
{
smem[lds_offset] = ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx];
}
}
}
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(kBlockM_/32); m_idx++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
int lds_offset = warp_id*33*kBlockM_ + m_idx*32*33 + min_tile_m*16 + vec_idx*4 + (lane_id>>4) + min_tile_n*16*33 + (lane_id&15)*33;
ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx] = smem[lds_offset];
}
}
}
}
}
}
/*
* q_ptr: Transposed 32x16 matrix
* k_ptr: Non-transposed 32x16 matrix
* qk_ptr: Non-transposed 32x32 matrixseqlen_q
*/
template<class DataType>
int check_param(int seqlen_q, int seqlen_k, int K, int kBlockM_, int kBlockN_, int kBlockK_, int WARP_M_, int WARP_N_, dim3 blockDim, dim3 gridDim, int maxBlockThreads, int STAGES) {
// min warp size is 32x32
if(WARP_M_<32 || WARP_N_<32) {
std::cout<<"Error, WARP_M_<32 or WARP_N_<32!"<<std::endl;
assert(((WARP_M_>=32) && (WARP_N_>=32)));
}
// check block threads number
const int blockThreads = ((kBlockM_*kBlockN_)/(WARP_M_*WARP_N_)*64);
if(blockThreads > maxBlockThreads) {
std::cout<<"Error,Block threads is greater than maxBlockThreads! "<<std::endl;
assert(blockThreads <= maxBlockThreads);
}
//check lds data numbers
int DataTypeSize = sizeof(DataType);
const int q_lds_size = STAGES * kBlockM_ * kBlockK_ * DataTypeSize;
const int k_lds_size = STAGES * kBlockN_ * kBlockK_ * DataTypeSize;
if(((q_lds_size + k_lds_size)/1024) > 64) {
std::cout<<"Error, shared memory size is greater than 64KB"<<std::endl;
assert(((q_lds_size + k_lds_size)/1024) <= 64); //BW lds 64KB
}
}
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + (kBlockN_/32)*(K/32)*(32*34);//(Is_preload_K || Is_store_K) ? (Element*)&(smem) + (kBlockN_/32)*(K/32)*(32*34) : (Element*)&(smem);
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
// Element * gQ = reinterpret_cast<Element *>(q_ptr) + row_offset_q;
auto gQ = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q);
// Element * gK = reinterpret_cast<Element *>(k_ptr) + row_offset_k;
auto gK = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k);
// Element * gV = reinterpret_cast<Element *>(v_ptr) + row_offset_v;
auto gV = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v);
// Element * gdO = reinterpret_cast<Element *>(do_ptr) + row_offset_dO;
auto gdO = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
auto gdQ = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(dq_ptr));
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec2_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2][2];
union_vec2_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2][2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + ((lane_id & 15)*2) + min_tile_m;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + ((lane_id & 15)*2) + min_tile_m;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr<K, kBlockM_, kBlockK_, WARP_N_, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), seqlen_q_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr<K_v, kBlockM_, kBlockK_, WARP_N_, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), seqlen_do_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_K){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gK, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id, seqlen_k_stride);
}
if constexpr (Is_preload_V){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gV, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id, seqlen_v_stride);
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec2_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2][2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
{
//dp gemm
gemm_tt_kq<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec2_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2][2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gV, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id, seqlen_v_stride);
}
apply_mask_bwd<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec2_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)][4];
convert_pk_type<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg<Is_store_K , false , false, Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0){
// printf("(binfo.actual_seqlen_k - n_block * kBlockN_) = %d\n", (binfo.actual_seqlen_k - n_block * kBlockN_));
// }
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gK, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id, seqlen_k_stride);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
{
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
int dq_block_buffer_store_global_offset = k_loop * kBlockK_;
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
int dq_warp_buffer_store_global_offset = (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx*2) * seqlen_dq_stride;
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
dq_global_addr_offset = dq_block_buffer_store_global_offset + dq_warp_buffer_store_global_offset + k_tile_idx*32;
#pragma unroll 2
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
#pragma unroll 2
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int dq_global_addr = dq_global_addr_offset + (min_tile_m + vec_index*8)*seqlen_dq_stride + min_tile_k + dq_lane_head_dim_idx*2;
if(Is_even_MN || ((m_block * kBlockM_) + (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx*2) + min_tile_m + vec_index*8) < binfo.actual_seqlen_q) {
dq_ptr[dq_global_addr] = DownCast<ElementAccum, Element>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_k + min_tile_m*2].f32[vec_index] * params.scale_softmax_rp_dropout);
}
}
}
}
}
}
}
}
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock_gfx938(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + kBlockN_* K;
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_V){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
if constexpr (Is_preload_K){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec4_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
//dP gemm
gemm_tt_kq_gfx938<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride
);
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq_gfx938<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride
);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)*2];
convert_pk_type_gfx938<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg_gfx938<Is_store_K , false , Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
//mmac
{
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
int dq_block_buffer_store_global_offset = k_loop * kBlockK_;
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
int dq_warp_buffer_store_global_offset = (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx) * seqlen_dq_stride;
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
dq_global_addr_offset = dq_block_buffer_store_global_offset + dq_warp_buffer_store_global_offset + k_tile_idx*32;
#pragma unroll 2
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
#pragma unroll 2
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int dq_global_addr = dq_global_addr_offset + (min_tile_m*16 + vec_index*4)*seqlen_dq_stride + min_tile_k + dq_lane_head_dim_idx*2;
if(Is_even_MN || ((m_block * kBlockM_) + (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx) + min_tile_m*16 + vec_index*4) < binfo.actual_seqlen_q) {
dq_ptr[dq_global_addr] = DownCast<ElementAccum, Element>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_k + min_tile_m*2].f32[vec_index] * params.scale_softmax_rp_dropout);
}
}
}
}
}
}
}
}
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock_gfx946(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + kBlockN_* K;
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_V){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
if constexpr (Is_preload_K){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec4_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
//dP gemm
gemm_tt_kq_gfx938<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride
);
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq_gfx938<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride
);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
// dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)*2];
convert_pk_type_gfx938<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg_gfx946<Is_store_K , false , Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0){
// printf("(binfo.actual_seqlen_k - n_block * kBlockN_) = %d\n", (binfo.actual_seqlen_k - n_block * kBlockN_));
// }
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
dq_ptr = dq_ptr + binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh);
auto gdQ = tcp_cache_swizzle_func<K_v, Element>(dq_ptr);
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dq_lane_head_dim_idx * seqlen_dq_stride + dq_lane_seq_idx * 4;
int s_offset = (min_tile_m * seqlen_dq_stride * 16 + vec_index % 2 * 2 + vec_index / 2 * 16) + (k_tile_idx*32) + ((warp_id*WARP_M_ + warp_m_idx*32) * seqlen_dq_stride) + (k_loop * kBlockK_ + m_block * kBlockM_ * seqlen_dq_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_m*2 + vec_index / 2].f32[vec_index % 2 * 2] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_m*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1] * params.scale_softmax_rp_dropout);
if (Is_even_MN || min_tile_m*16 + (warp_id*WARP_M_ + warp_m_idx*32) + m_block * kBlockM_ + dq_lane_head_dim_idx < binfo.actual_seqlen_q){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdQ, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
#endif
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
#define print_kq(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + qk_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2) *params.seqlen_k + min_tile_n ; \
kq_ptr[offset + reg_id * 8 *params.seqlen_k] = 0;(s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
// #define print_kq(block_id_m, bidb, bidh) { \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
// int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
// int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
// + block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
// for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
// for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
// for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
// for(int reg_id=0; reg_id<4; reg_id++) { \
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
// for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
// if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
// ((block_id_m*kBlockM_ + qk_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
// int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
// kq_ptr[offset + reg_id *params.seqlen_k] = 0;(s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
// } \
// } \
// } \
// } \
// } \
// } \
// } \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// }
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int s_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int s_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int s_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + s_warp_n_id*WARP_N_ + s_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + s_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + s_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = s_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2)*params.seqlen_k + min_tile_n ;\
s_ptr[offset + reg_id * 8 * params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2)*params.seqlen_k + min_tile_n ;\
ds_ptr[offset + reg_id * 8 * params.seqlen_k] = (dS_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds_fp16(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2)*params.seqlen_k + min_tile_n ;\
ds_ptr[offset + reg_id * 8 * params.seqlen_k] = UpCast<Element,float>(dS_reg_fp16[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f16[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_dp(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int dp_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int dp_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int dp_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_id*WARP_N_ + dp_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {\
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + dp_warp_n_id*WARP_N_ + n_idx * 32 + ((lane_id & 15) << 1) + min_tile_n) < params.seqlen_k) && \
((block_id_m*kBlockM_ + dp_warp_m_id*WARP_M_ + m_idx*32 + reg_id * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q)) { \
int offset = dp_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ ((lane_id & 15) << 1) + min_tile_m*params.seqlen_k + ((lane_id / 16) * 2)*params.seqlen_k + min_tile_n ;\
dp_ptr[offset + reg_id * 8 * params.seqlen_k] = (dp_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
}\
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
template<class Element, class ElementAccumType, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, bool USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dk_dv_1colblock(Params &params, int bidb, int bidh, int n_block
) {
#ifdef DEBUGING
ElementAccumType * kq_ptr = static_cast<ElementAccumType*>(params.kq_ptr);
ElementAccumType * s_ptr = static_cast<ElementAccumType*>(params.s_ptr);
ElementAccumType * dp_ptr = static_cast<ElementAccumType*>(params.dp_ptr);
ElementAccumType * ds_ptr = static_cast<ElementAccumType*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* p_ptr = static_cast<Element*>(params.p_ptr);
// Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccumType* softmax_lse_ptr = static_cast<ElementAccumType*>(params.softmax_lse_ptr);
ElementAccumType* dsoftmax_sum = static_cast<ElementAccumType*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
// static_assert(kBlockM_=WARP_M_,"Error: kBlockM_ not equal WARP_M_!");
const int WARP_NUM = (kBlockM_*kBlockN_)/(WARP_M_*WARP_N_);
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
int K_lds_ratio;
const int K_prefetch_level = 3;
const int STAGES = 2;
const bool Is_store_Q = true;
const bool Is_store_dO = true;
const bool Is_preload_Q = true;
const bool Is_preload_dO = true;
const int dP_dO_prefetch_level = Is_store_dO ? 1 : 0;
const int Q_prefetech_level = Is_preload_Q ? 1 : 0;
if constexpr (K_prefetch_level == 2){
K_lds_ratio = (K / kBlockK_) / 2;
} else {
K_lds_ratio = (K_prefetch_level == 3) ? 0 : STAGES;
}
// Element* K_lds = (Element*)&(smem);
// Element* Q_lds = K_lds + (kBlockN_/32) * (kBlockK_/32)*(32*34) * K_lds_ratio;
// Element* V_lds = K_prefetch_level == 2 ? Q_lds : K_lds;
// Element* dO_lds = Q_lds;
Element* K_lds = (Element*)&(smem);
Element* dO_lds = K_lds + (kBlockN_/32) * (kBlockK_/32)*(32*34) * K_lds_ratio;
Element* V_lds = K_prefetch_level == 2 ? dO_lds : K_lds;
Element* Q_lds = Is_store_Q ? dO_lds + (kBlockM_/32) * (K_v/32)*(32*34) : dO_lds;
#if 0//defined(__gfx936__)
auto pointwise_mult = [](vec2_fp32 p, vec2_fp32 dp, vec2_fp32 d) {
auto d0 = (!Is_dropout || p[0] >= 0 ? dp[0] - d[0] : d[0]);
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return __builtin_hcu_v_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#endif
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (n_block * kBlockN_ >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
const int m_block_min = (!Is_causal && !Is_local) ? 0 : std::max(0, (n_block * kBlockN_ - params.window_size_right) / kBlockM_);
const int m_block_max = !Is_local ? ceil_div(binfo.actual_seqlen_q, kBlockM_) : std::min(ceil_div(binfo.actual_seqlen_q, kBlockM_), ceil_div((n_block + 1) * kBlockN_ + params.window_size_left, kBlockM_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dk_stride = params.dk_row_stride;
int seqlen_dv_stride = params.dv_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + (m_block_max - 1) * kBlockM_* seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_o_stride;
// const int row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + (m_block_max - 1) * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + (m_block_max - 1) * kBlockM_;
// Element * gQ = reinterpret_cast<Element *>(q_ptr) + row_offset_q;
auto gQ = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q);
// Element * gK = reinterpret_cast<Element *>(k_ptr) + row_offset_k;
auto gK = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k);
// Element * gV = reinterpret_cast<Element *>(v_ptr) + row_offset_v;
auto gV = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v);
// Element * gdO = reinterpret_cast<Element *>(do_ptr) + row_offset_dO;
auto gdO = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
// Element * gdQ = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
ElementAccumType *gLSE = reinterpret_cast<ElementAccumType *>(softmax_lse_ptr) + row_offset_lse;
ElementAccumType *gdPsum = reinterpret_cast<ElementAccumType *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int m_masking_steps = (!Is_causal && !Is_local)
? 0
: flash::ceil_div(kBlockN_, kBlockM_);
/***************************************************************************************************************************/
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec2_f16x2<Element> k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level == 3)? 1 : 2)][2]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
union_vec2_f16x2<Element> v_reg[(K_v/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2][2];
__builtin_amdgcn_sched_barrier(0);
prefetch_to_vgpr<K_v, kBlockN_, kBlockK_, WARP_N_, Element, ElementAccumType, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), seqlen_v_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr<K, kBlockN_, kBlockK_, WARP_N_, Element, ElementAccumType, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), seqlen_k_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
if constexpr (Is_preload_Q){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, Element>(gQ, Q_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id, seqlen_q_stride);
}
if constexpr (Is_preload_dO){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, Element>(gdO, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id, seqlen_do_stride);
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
union_vec4_fp32 acc_dk[(K/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
for (int m_block = m_block_max - 1; m_block >= m_block_min; --m_block) {
union_vec2_f16x2<Element> q_reg[((WARP_M_*kBlockK_)/(32*32))*2][2];
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq<Is_store_Q, Is_preload_dO, Is_even_MN, K_prefetch_level, Q_prefetech_level, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(gK, gQ, K_lds, Q_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), k_reg, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_q_stride);
float lse[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int lse_idx = mi*32 + vec_idx*8 + (lane_id >> 4)*2 + min_tile_m;
lse[(mi*2 + min_tile_m)*4 + vec_idx] = Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_ ? gLSE[lse_idx] : INFINITY;
}
}
}
apply_mask_bwd<Is_even_MN, Is_local ? 3 : (Is_causal ? 2 : 0)>(s_reg, binfo.actual_seqlen_k - n_block * kBlockN_ - warp_id * 32, binfo.actual_seqlen_q - m_block * kBlockM_, (n_block * kBlockN_ + warp_id * 32) - m_block * kBlockM_, params.window_size_right, params.window_size_left);
#ifdef DEBUGING
print_kq(m_block, bidb, bidh);
#endif
float dP_sum_reg[kBlockM_/4];
#pragma unroll
for (int vec_idx = 0; vec_idx < (kBlockM_/8); ++vec_idx) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
dP_sum_reg[vec_idx*2 + min_tile_m] = gdPsum[vec_idx*8 + ((lane_id >> 4)*2) + min_tile_m];
}
}
{
scale_apply_exp2_bwd</*scale_max=*/false, kBlockM_, WARP_N_>(s_reg, lse, params.scale_softmax_log2);
}
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh);
#endif
// //TODO:drop
union_vec2_f16x2<Element> p_reg[(kBlockM_/32)*(WARP_N_/32)][4];
convert_pk_type<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
//QK(seq_q, seq_kv), seq_q is continuous, seq_kv is not continuous
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dv gemm,dO*P
gpu_gemm_B_in_reg<Is_preload_dO, Is_store_dO, false, Is_even_MN, K_v, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gdO, gQ, dO_lds, p_reg, acc_dv, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_do_stride);
}
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
union_vec2_f16x2<Element> dO_reg[((WARP_M_*kBlockK_)/(32*32))*2][2];
union_vec4_fp32 dp_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
{
// dP gemm
gemm_tt_kq<Is_store_dO, false, Is_even_MN, 3, dP_dO_prefetch_level, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gV, gdO, V_lds, dO_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), v_reg, dO_reg, dp_reg, warp_id, seqlen_v_stride, seqlen_do_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_fp32 dS_reg[(WARP_N_/32)*(kBlockM_/32)][4];
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int ni = 0; ni < (WARP_N_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#if 0//defined(__gfx936__)
#pragma unroll
for(int vec_idx=0; vec_idx<2; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
vec2_fp32{gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m], gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m + 8]});
}
#else
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
// if((m_block*kBlockM_ + vec_idx * 8 + min_tile_m + ((lane_id / 16) * 2)) < params.seqlen_q){
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx] = pointwise_mult(s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx], dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx], dP_sum_reg[vec_idx*2 + min_tile_m]);
// }
// else{
// dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx] = 0;
// }
}
#endif
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec2_f16x2<Element> dS_reg_fp16[(WARP_N_/32)*(kBlockM_/32)][4];
convert_pk_type<kBlockM_, WARP_N_, Element>(dS_reg_fp16, dS_reg);
// #ifdef DEBUGING
// print_ds_fp16(m_block, bidb, bidh);
// #endif
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
{
//dk gemm, Q*dS
gpu_gemm_B_in_reg<Is_store_Q , false , false, Is_even_MN, K, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gQ, gdO, Q_lds, dS_reg_fp16, acc_dk, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_q_stride);
}
gLSE = gLSE + (-int(kBlockM_));
gdPsum = gdPsum - kBlockM_;
*(uint64_t*)&gQ -= ((kBlockM_ * seqlen_q_stride) * sizeof(Element));
*(uint64_t*)&gdO -= ((kBlockM_ * seqlen_do_stride) * sizeof(Element));
{
__syncthreads();
if (Is_preload_Q && m_block > m_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, Element>(gQ, Q_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id, seqlen_q_stride);
}
// __syncthreads();
if (Is_preload_dO && m_block > m_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, Element>(gdO, dO_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id, seqlen_do_stride);
}
}
}
{
// dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
int dv_lane_seq_idx = (lane_id >> 4);
int dv_lane_head_dim_idx = (lane_id & 15);
int dv_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = (dv_lane_head_dim_idx*2) + (dv_lane_seq_idx*2 * seqlen_dv_stride);
int s_offset = (min_tile_n*seqlen_dv_stride + vec_index * 8 * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index]);
v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + warp_n_idx*32 + dv_lane_seq_idx*2 + min_tile_n + vec_index * 8 < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
{
// dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
int dk_lane_seq_idx = (lane_id >> 4);
int dk_lane_head_dim_idx = (lane_id & 15);
int dk_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
vec2_Element<Element> v_data;
int v_offset = dk_lane_head_dim_idx*2 + (dk_lane_seq_idx*2) * seqlen_dk_stride;
int s_offset = n_block * kBlockN_ * seqlen_dk_stride + (warp_id*WARP_N_) * seqlen_dk_stride + (min_tile_n*seqlen_dk_stride + vec_index * 8 * seqlen_dk_stride + k_tile_idx*32 + k_loop * kBlockK_ + warp_n_idx*32);
int known_offset = 0;
v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + dk_lane_seq_idx*2 + min_tile_n + vec_index * 8 < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
}
#undef print_kq
#undef print_dq
#undef print_softmax_rescale_o
#undef print_ds
#undef print_ds_fp16
#undef print_dp
#define print_kq(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + qk_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
kq_ptr[offset + reg_id *params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int s_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int s_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int s_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + s_warp_n_id*WARP_N_ + s_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + s_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + s_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = s_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
s_ptr[offset + reg_id * params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
ds_ptr[offset + reg_id * params.seqlen_k] = (dS_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds_fp16(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
ds_ptr[offset + reg_id * 8 * params.seqlen_k] = UpCast<Element,float>(dS_reg_fp16[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f16[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_dp(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int dp_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int dp_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int dp_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_id*WARP_N_ + dp_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {\
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + dp_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + dp_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = dp_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
dp_ptr[offset + reg_id * params.seqlen_k] = (dp_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]);\
} \
} \
} \
} \
} \
}\
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
/*
load q/k:累加方向为主序方向
ps: 在offset传0的情况下,T和R的取值似乎没有影响!?
调用matrix_load_32x32_b16:
R=0: offset in column direction
load Q: T=1: row major
load K: T=0: column major
m_ab=1: 线程数据在主序方向拼接
调用ds_read_matrix_trans_format(和m_ab保持一致):
element:0x2 row:0x2 col:0x1 alt:0x0
load v:累加方向为非主序方向
调用matrix_load_32x32_b16:
R=0: offset in column direction
T=1: row major
m_ab=0: 线程数据在非主序方向拼接
调用ds_read_matrix_format(和m_ab保持一致)
*/
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, bool USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, int bidb, int bidh, int n_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* p_ptr = static_cast<Element*>(params.p_ptr);
// Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
// static_assert(kBlockM_=WARP_M_,"Error: kBlockM_ not equal WARP_M_!");
const int WARP_NUM = (kBlockM_*kBlockN_)/(WARP_M_*WARP_N_);
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
int K_lds_ratio;
// 0表示k不预取;1表示k预取一半到寄存器;2表示一半到寄存器,一半到LDS;3表示全部预取到寄存器
const int K_prefetch_level = 3;
const int STAGES = 2;
const bool Is_store_Q = true;
const bool Is_store_dO = true;
const bool Is_preload_Q = true;
const bool Is_preload_dO = true;
const int dP_dO_prefetch_level = Is_store_dO ? 1 : 0;
const int Q_prefetech_level = Is_preload_Q ? 1 : 0;
if constexpr (K_prefetch_level == 2){
K_lds_ratio = (K / kBlockK_) / 2;
} else {
K_lds_ratio = (K_prefetch_level == 3) ? 0 : STAGES;
}
Element* K_lds = (Element*)&(smem);
Element* dO_lds = K_lds + kBlockN_ * kBlockK_ * K_lds_ratio;
Element* V_lds = K_prefetch_level == 2 ? dO_lds : K_lds;
Element* Q_lds = Is_store_Q ? dO_lds + kBlockM_ * K_v : dO_lds;
#if 0//defined(__gfx938__)
auto pointwise_mult = [](vec2_fp32 p, vec2_fp32 dp, vec2_fp32 d) {
auto d0 = (!Is_dropout || p[0] >= 0 ? dp[0] - d[0] : d[0]);
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#endif
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (n_block * kBlockN_ >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
const int m_block_min = (!Is_causal && !Is_local) ? 0 : std::max(0, (n_block * kBlockN_ - params.window_size_right) / kBlockM_);
const int m_block_max = !Is_local ? ceil_div(binfo.actual_seqlen_q, kBlockM_) : std::min(ceil_div(binfo.actual_seqlen_q, kBlockM_), ceil_div((n_block + 1) * kBlockN_ + params.window_size_left, kBlockM_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dk_stride = params.dk_row_stride;
int seqlen_dv_stride = params.dv_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + (m_block_max - 1) * kBlockM_* seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_o_stride;
// const int row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + (m_block_max - 1) * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + (m_block_max - 1) * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int m_masking_steps = (!Is_causal && !Is_local)
? 0
: flash::ceil_div(kBlockN_, kBlockM_);
/***************************************************************************************************************************/
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level == 3)? 1 : 2)]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
union_vec4_f16x2<Element> v_reg[(K_v/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2];
//提前读取V到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
//提前读取K到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
//提前读取Q到lds
if constexpr (Is_preload_Q){
prefetch_to_lds_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, 0, Q_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
//提前读取dO到lds
if constexpr (Is_preload_dO){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
union_vec4_fp32 acc_dk[(K/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
for (int m_block = m_block_max - 1; m_block >= m_block_min; --m_block) {
union_vec4_f16x2<Element> q_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
/*
qk gemm
结果矩阵layout:
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
...
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
*/
gemm_tt_kq_gfx938<Is_store_Q, Is_preload_dO, Is_even_MN, K_prefetch_level, Q_prefetech_level, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gK, gQ, K_lds, Q_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), k_reg, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_q_stride);
/*
lse layout:
4 warp:
32
32
32
32
因为warp在seqlen_k维度,所以不区分warp
每16个thread持有相同的lse,所以需要/4
*/
float lse[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int lse_idx = mi*32 + min_tile_m * 16 + (lane_id >> 4)*4 + vec_idx;
lse[(mi*2 + min_tile_m)*4 + vec_idx] = Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_ ? gLSE[lse_idx] : INFINITY;
}
}
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 2 : 0)>(s_reg, binfo.actual_seqlen_k - n_block * kBlockN_ - warp_id * 32, binfo.actual_seqlen_q - m_block * kBlockM_, (n_block * kBlockN_ + warp_id * 32) - m_block * kBlockM_, params.window_size_right, params.window_size_left);
#ifdef DEBUGING
print_kq(m_block, bidb, bidh);
#endif
//do . o后在headdim维度reduce求和,读取方式和lse一样,因为pad了,所以无需边界判断
float dP_sum_reg[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int dPsum_idx = mi*32 + min_tile_m * 16 + (lane_id >> 4)*4 + vec_idx;
dP_sum_reg[(mi*2 + min_tile_m)*4 + vec_idx] = gdPsum[dPsum_idx];
}
}
}
{
scale_apply_exp2_bwd</*scale_max=*/false, kBlockM_, WARP_N_>(s_reg, lse, params.scale_softmax_log2);
}
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh);
#endif
// //TODO:drop
union_vec4_f16x2<Element> p_reg[(kBlockM_/32)*(WARP_N_/32)*2];
// convert_pk_type<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
convert_pk_type_gfx938<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
//QK(seq_q, seq_kv), seq_q is continuous, seq_kv is not continuous
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dv gemm,dO*P
gpu_gemm_B_in_reg_gfx938<Is_preload_dO, Is_store_dO, Is_even_MN, K_v, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gdO, gQ, dO_lds, p_reg, acc_dv, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_do_stride);
}
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
union_vec4_f16x2<Element> dO_reg[((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
{
// dP gemm dO * V
gemm_tt_kq_gfx938<Is_store_dO, false, Is_even_MN, 3, dP_dO_prefetch_level, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gV, gdO, V_lds, dO_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), v_reg, dO_reg, dp_reg, warp_id, seqlen_v_stride, seqlen_do_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_fp32 dS_reg[(WARP_N_/32)*(kBlockM_/32)][4];
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int ni = 0; ni < (WARP_N_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#if 0//defined(__gfx938__)
#pragma unroll
for(int vec_idx=0; vec_idx<2; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
vec2_fp32{dP_sum_reg[min_tile_m*4 + vec_idx * 2], dP_sum_reg[min_tile_m*4 + vec_idx * 2 + 1]});
}
#else
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx],
dP_sum_reg[min_tile_m*4 + vec_idx]);
}
#endif
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_N_/32)*(kBlockM_/32)*2];
convert_pk_type_gfx938<kBlockM_, WARP_N_, Element>(dS_reg_fp16, dS_reg);
// #ifdef DEBUGING
// print_ds_fp16(m_block, bidb, bidh);
// #endif
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dk gemm, Q*dS
gpu_gemm_B_in_reg_gfx938<Is_store_Q , false, Is_even_MN, K, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gQ, gdO, Q_lds, dS_reg_fp16, acc_dk, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_q_stride);
}
gLSE = gLSE + (-int(kBlockM_));
gdPsum = gdPsum - kBlockM_;
*(uint64_t*)&gQ -= ((kBlockM_ * seqlen_q_stride) * sizeof(Element));
*(uint64_t*)&gdO -= ((kBlockM_ * seqlen_do_stride) * sizeof(Element));
{
__syncthreads();
if (Is_preload_Q && m_block > m_block_min){
prefetch_to_lds_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, 0, Q_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id);
}
// __syncthreads();
if (Is_preload_dO && m_block > m_block_min){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id);
}
}
}
{
// dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
int dv_lane_seq_idx = (lane_id >> 4);
int dv_lane_head_dim_idx = (lane_id & 15);
int dv_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dv_lane_head_dim_idx*2 + dv_lane_seq_idx * seqlen_dv_stride;
int s_offset = (min_tile_n*seqlen_dv_stride*16 + vec_index * 4 * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index]);
v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + warp_n_idx*32 + dv_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
{
// dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
int dk_lane_seq_idx = (lane_id >> 4);
int dk_lane_head_dim_idx = (lane_id & 15);
int dk_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
vec2_Element<Element> v_data;
int v_offset = dk_lane_head_dim_idx*2 + dk_lane_seq_idx * seqlen_dk_stride;
int s_offset = n_block * kBlockN_ * seqlen_dk_stride + (warp_id*WARP_N_) * seqlen_dk_stride + (min_tile_n*seqlen_dk_stride*16 + vec_index * 4 * seqlen_dk_stride + k_tile_idx*32 + k_loop * kBlockK_ + warp_n_idx*32);
int known_offset = 0;
v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + dk_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
}
#undef print_dq
#undef print_softmax_rescale_o
#undef print_ds
#undef print_ds_fp16
#undef print_dp
#define print_kq(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + qk_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
kq_ptr[offset + reg_id *params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int s_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int s_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int s_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + s_warp_n_id*WARP_N_ + s_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + s_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + s_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = s_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
s_ptr[offset + reg_id * params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
ds_ptr[offset + reg_id * params.seqlen_k] = (dS_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds_fp16(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
ds_ptr[offset + reg_id * params.seqlen_k] = UpCast<Element,float,true>(dS_reg_fp16[(m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx)*2 + min_tile_m].f16[min_tile_n*4 + reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
// #define print_ds_fp16(block_id_m, bidb, bidh) { \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
// int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
// int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
// for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
// for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
// for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
// for(int reg_id=0; reg_id<4; reg_id++) { \
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
// for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
// if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
// ((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
// int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
// ds_ptr[offset + reg_id * 8 * params.seqlen_k] = UpCast<Element,float>(dS_reg_fp16[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f16[reg_id]);\
// } \
// } \
// } \
// } \
// } \
// } \
// } \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// }
#define print_dp(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int dp_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int dp_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int dp_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_id*WARP_N_ + dp_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {\
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + dp_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + dp_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = dp_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
dp_ptr[offset + reg_id * params.seqlen_k] = (dp_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]);\
} \
} \
} \
} \
} \
}\
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
/*
load q/k:累加方向为主序方向
ps: 在offset传0的情况下,T和R的取值似乎没有影响!?
调用matrix_load_32x32_b16:
R=0: offset in column direction
load Q: T=1: row major
load K: T=0: column major
m_ab=1: 线程数据在主序方向拼接
调用ds_read_matrix_trans_format(和m_ab保持一致):
element:0x2 row:0x2 col:0x1 alt:0x0
load v:累加方向为非主序方向
调用matrix_load_32x32_b16:
R=0: offset in column direction
T=1: row major
m_ab=0: 线程数据在非主序方向拼接
调用ds_read_matrix_format(和m_ab保持一致)
*/
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, bool USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dk_dv_1colblock_gfx946(Params &params, int bidb, int bidh, int n_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* p_ptr = static_cast<Element*>(params.p_ptr);
// Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
// static_assert(kBlockM_=WARP_M_,"Error: kBlockM_ not equal WARP_M_!");
const int WARP_NUM = (kBlockM_*kBlockN_)/(WARP_M_*WARP_N_);
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
int K_lds_ratio;
// 0表示k不预取;1表示k预取一半到寄存器;2表示一半到寄存器,一半到LDS;3表示全部预取到寄存器
const int K_prefetch_level = 3;
const int STAGES = 2;
const bool Is_store_Q = true;
const bool Is_store_dO = true;
const bool Is_preload_Q = true;
const bool Is_preload_dO = true;
const int dP_dO_prefetch_level = Is_store_dO ? 1 : 0;
const int Q_prefetech_level = Is_preload_Q ? 1 : 0;
if constexpr (K_prefetch_level == 2){
K_lds_ratio = (K / kBlockK_) / 2;
} else {
K_lds_ratio = (K_prefetch_level == 3) ? 0 : STAGES;
}
Element* K_lds = (Element*)&(smem);
Element* dO_lds = K_lds + kBlockN_ * kBlockK_ * K_lds_ratio;
Element* V_lds = K_prefetch_level == 2 ? dO_lds : K_lds;
Element* Q_lds = Is_store_Q ? dO_lds + kBlockM_ * K_v : dO_lds;
#if 0//defined(__gfx936__)
auto pointwise_mult = [](vec2_fp32 p, vec2_fp32 dp, vec2_fp32 d) {
auto d0 = (!Is_dropout || p[0] >= 0 ? dp[0] - d[0] : d[0]);
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return __builtin_hcu_v_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#endif
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (n_block * kBlockN_ >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
const int m_block_min = (!Is_causal && !Is_local) ? 0 : std::max(0, (n_block * kBlockN_ - params.window_size_right) / kBlockM_);
const int m_block_max = !Is_local ? ceil_div(binfo.actual_seqlen_q, kBlockM_) : std::min(ceil_div(binfo.actual_seqlen_q, kBlockM_), ceil_div((n_block + 1) * kBlockN_ + params.window_size_left, kBlockM_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dk_stride = params.dk_row_stride;
int seqlen_dv_stride = params.dv_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + (m_block_max - 1) * kBlockM_* seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_o_stride;
// const int row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + (m_block_max - 1) * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + (m_block_max - 1) * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int m_masking_steps = (!Is_causal && !Is_local)
? 0
: flash::ceil_div(kBlockN_, kBlockM_);
/***************************************************************************************************************************/
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level == 3)? 1 : 2)]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
union_vec4_f16x2<Element> v_reg[(K_v/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2];
//提前读取V到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
//提前读取K到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
//提前读取Q到lds
if constexpr (Is_preload_Q){
prefetch_to_lds_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, 0, Q_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
//提前读取dO到lds
if constexpr (Is_preload_dO){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
union_vec4_fp32 acc_dk[(K/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
for (int m_block = m_block_max - 1; m_block >= m_block_min; --m_block) {
union_vec4_f16x2<Element> q_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
/*
qk gemm
结果矩阵layout:
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
...
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
*/
gemm_tt_kq_gfx938<Is_store_Q, Is_preload_dO, Is_even_MN, K_prefetch_level, Q_prefetech_level, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gK, gQ, K_lds, Q_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), k_reg, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_q_stride);
/*
lse layout:
4 warp:
32
32
32
32
因为warp在seqlen_k维度,所以不区分warp
每16个thread持有相同的lse,所以需要/4
*/
float lse[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int lse_idx = mi*32 + min_tile_m * 16 + (lane_id >> 4)*4 + vec_idx;
lse[(mi*2 + min_tile_m)*4 + vec_idx] = Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_ ? gLSE[lse_idx] : INFINITY;
}
}
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 2 : 0)>(s_reg, binfo.actual_seqlen_k - n_block * kBlockN_ - warp_id * 32, binfo.actual_seqlen_q - m_block * kBlockM_, (n_block * kBlockN_ + warp_id * 32) - m_block * kBlockM_, params.window_size_right, params.window_size_left);
#ifdef DEBUGING
print_kq(m_block, bidb, bidh);
#endif
//do . o后在headdim维度reduce求和,读取方式和lse一样,因为pad了,所以无需边界判断
float dP_sum_reg[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int dPsum_idx = mi*32 + min_tile_m * 16 + (lane_id >> 4)*4 + vec_idx;
dP_sum_reg[(mi*2 + min_tile_m)*4 + vec_idx] = gdPsum[dPsum_idx];
}
}
}
{
scale_apply_exp2_bwd</*scale_max=*/false, kBlockM_, WARP_N_>(s_reg, lse, params.scale_softmax_log2);
}
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh);
#endif
// //TODO:drop
union_vec4_f16x2<Element> p_reg[(kBlockM_/32)*(WARP_N_/32)*2];
// convert_pk_type<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
convert_pk_type_gfx938<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
//QK(seq_q, seq_kv), seq_q is continuous, seq_kv is not continuous
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dv gemm,dO*P
gpu_gemm_B_in_reg_gfx946<Is_preload_dO, Is_store_dO, Is_even_MN, K_v, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gdO, gQ, dO_lds, p_reg, acc_dv, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_do_stride);
}
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
union_vec4_f16x2<Element> dO_reg[((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
{
// dP gemm dO * V
gemm_tt_kq_gfx938<Is_store_dO, false, Is_even_MN, 3, dP_dO_prefetch_level, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gV, gdO, V_lds, dO_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), v_reg, dO_reg, dp_reg, warp_id, seqlen_v_stride, seqlen_do_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_fp32 dS_reg[(WARP_N_/32)*(kBlockM_/32)][4];
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int ni = 0; ni < (WARP_N_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#if 0//defined(__gfx936__)
#pragma unroll
for(int vec_idx=0; vec_idx<2; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
vec2_fp32{gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m], gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m + 8]});
}
#else
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx],
dP_sum_reg[min_tile_m*4 + vec_idx]);
}
#endif
}
}
}
}
// #ifdef DEBUGING
// print_ds(m_block, bidb, bidh);
// #endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_N_/32)*(kBlockM_/32)*2];
convert_pk_type_gfx938<kBlockM_, WARP_N_, Element>(dS_reg_fp16, dS_reg);
// #ifdef DEBUGING
// print_ds_fp16(m_block, bidb, bidh);
// #endif
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dk gemm, Q*dS
gpu_gemm_B_in_reg_gfx946<Is_store_Q , false, Is_even_MN, K, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gQ, gdO, Q_lds, dS_reg_fp16, acc_dk, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_q_stride);
}
gLSE = gLSE + (-int(kBlockM_));
gdPsum = gdPsum - kBlockM_;
*(uint64_t*)&gQ -= ((kBlockM_ * seqlen_q_stride) * sizeof(Element));
*(uint64_t*)&gdO -= ((kBlockM_ * seqlen_do_stride) * sizeof(Element));
{
__syncthreads();
if (Is_preload_Q && m_block > m_block_min){
prefetch_to_lds_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, 0, Q_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id);
}
// __syncthreads();
if (Is_preload_dO && m_block > m_block_min){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id);
}
}
}
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
// dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
int dv_lane_seq_idx = (lane_id >> 4);
int dv_lane_head_dim_idx = (lane_id & 15);
int dv_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dv_lane_head_dim_idx * seqlen_dv_stride + dv_lane_seq_idx * 4;
int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index % 2 * 2 + vec_index / 2 * 16) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2]);
v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1]);
if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dv_lane_head_dim_idx < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
#endif
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
auto gdK = tcp_cache_swizzle_func<K_v, Element>(dk_ptr);
int dk_lane_seq_idx = (lane_id >> 4);
int dk_lane_head_dim_idx = (lane_id & 15);
int dk_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dk_lane_head_dim_idx * seqlen_dk_stride + dk_lane_seq_idx * 4;
int s_offset = (min_tile_n * seqlen_dk_stride * 16 + vec_index % 2 * 2 + vec_index / 2 * 16) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dk_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dk_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1] * params.scale_softmax_rp_dropout);
if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dk_lane_head_dim_idx < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
#endif
// #if 1
// {
// // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// int dv_lane_seq_idx = (lane_id >> 4);
// int dv_lane_head_dim_idx = (lane_id & 15);
// int dv_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// int v_offset = dv_lane_head_dim_idx * seqlen_dv_stride + dv_lane_seq_idx * 8;
// int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// int known_offset = 0;
// vec2_Element<Element> v_data;
// v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 0].f32[vec_index]);
// v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
// if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dv_lane_head_dim_idx < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// #endif
// // //test only
// // {
// // // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// // dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// // auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// // int dv_lane_seq_idx = (lane_id >> 4);
// // int dv_lane_head_dim_idx = (lane_id & 15);
// // int dv_global_addr_offset=0;
// // #pragma unroll
// // for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// // #pragma unroll
// // for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// // #pragma unroll
// // for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// // #pragma unroll
// // for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// // #pragma unroll
// // for(int vec_index=0; vec_index<4; vec_index++) {
// // // int v_offset = dv_lane_head_dim_idx * seqlen_dv_stride + dv_lane_seq_idx * 8;
// // // int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// // int v_offset = dv_lane_head_dim_idx * 2 + dv_lane_seq_idx * 4 * seqlen_dv_stride;
// // int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// // int known_offset = 0;
// // vec2_Element<Element> v_data;
// // v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2]);
// // v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1]);
// // inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// // }
// // }
// // }
// // }
// // }
// // }
// {
// // dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
// dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
// auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
// int dk_lane_seq_idx = (lane_id >> 4);
// int dk_lane_head_dim_idx = (lane_id & 15);
// int dk_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// vec2_Element<Element> v_data;
// int v_offset = dk_lane_head_dim_idx * seqlen_dk_stride + dk_lane_seq_idx * 8;
// int s_offset = (min_tile_n * seqlen_dk_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dk_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dk_stride);
// int known_offset = 0;
// v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 0].f32[vec_index] * params.scale_softmax_rp_dropout);
// v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
// if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dk_lane_head_dim_idx < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// {
// // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// int dv_lane_seq_idx = (lane_id >> 4);
// int dv_lane_head_dim_idx = (lane_id & 15);
// int dv_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// int v_offset = dv_lane_head_dim_idx*2 + dv_lane_seq_idx * seqlen_dv_stride;
// int s_offset = (min_tile_n*seqlen_dv_stride*16 + vec_index * 4 * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// int known_offset = 0;
// vec2_Element<Element> v_data;
// v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index]);
// v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
// if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + warp_n_idx*32 + dv_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// {
// // dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
// dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
// auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
// int dk_lane_seq_idx = (lane_id >> 4);
// int dk_lane_head_dim_idx = (lane_id & 15);
// int dk_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// vec2_Element<Element> v_data;
// int v_offset = dk_lane_head_dim_idx*2 + dk_lane_seq_idx * seqlen_dk_stride;
// int s_offset = n_block * kBlockN_ * seqlen_dk_stride + (warp_id*WARP_N_) * seqlen_dk_stride + (min_tile_n*seqlen_dk_stride*16 + vec_index * 4 * seqlen_dk_stride + k_tile_idx*32 + k_loop * kBlockK_ + warp_n_idx*32);
// int known_offset = 0;
// v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index] * params.scale_softmax_rp_dropout);
// v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
// if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + dk_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
}
#undef print_dq
#undef print_softmax_rescale_o
#undef print_ds
#undef print_ds_fp16
#undef print_dp
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "intrinsic.h"
#include "prefetch.h"
// K BLOCK_K BLOCK_N BLOCK_M BLOCK_K WARP_N
template<bool Is_preload_A, bool Is_store_A, bool Is_preload_C, bool Is_even_MN, int M/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gpu_gemm_B_in_reg(vec4_uint A_ptr, vec4_uint C_ptr, Element* A_lds, union_vec2_f16x2<Element> B_reg[(WARP_M/32)*(BLOCK_K/32)][4], union_vec4_fp32 C_reg[(M/BLOCK_M)*(WARP_M/32)*(WARP_N/32)][4], int N/*seq_kv*/, int K/*seq_q*/, int warp_id, int seqlen_A_stride)
{
#if 1
const int WARP_NUM = (BLOCK_M*BLOCK_N)/(WARP_M*WARP_N);
const int A_lds_load_num = (BLOCK_M*BLOCK_K) / (4*32);
static_assert(BLOCK_K>=32, "Error: gpu_gemm_B_in_reg gemm BLOCK_K must be equal or greater than 32");
static_assert(BLOCK_N>=WARP_N, "Error: gpu_gemm_B_in_reg gemm BLOCK_N must be equal or greater than WARP_N");
static_assert(BLOCK_M==WARP_M, "Error: gpu_gemm_B_in_reg gemm BLOCK_M must be equal to WARP_M");
union_vec2_f16x2<Element> A_reg[((WARP_M*BLOCK_K)/(32*32))*2][2];
//c mini tile is 32*32
// vec4_fp32 o[(WARP_M/32)*(WARP_N/32)][4]={0};
// __shared__ Element A_lds[STAGES*BLOCK_N * BLOCK_K];
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
int stage_id = 0;
if(STAGES > 1 && (!Is_preload_A)) {
int m_loop = 0;
int A_block_buffer_load_global_offset = m_loop*BLOCK_M ; //+ k_loop * BLOCK_K * N;
// A_ptr buffer load mini size is 32*32, buffer_load_dword mini size is 4*32
int A_lane_m_idx = lane_id % 16;
// int A_lane_k_idx = lane_id / 16;
int A_lane_k_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
for(int warp_loop=warp_id; warp_loop<A_lds_load_num; warp_loop+=WARP_NUM) {
// for(int warp_loop_tmp = 0; warp_loop_tmp < A_lds_load_num / WARP_NUM; warp_loop_tmp++){
// int warp_loop = warp_loop_tmp * WARP_NUM + warp_id;
//global->lds, right matrix
int A_warp_buffer_load_k_id = (warp_loop / (BLOCK_M/32)); //seq_len
int A_warp_buffer_load_m_id = (warp_loop % (BLOCK_M/32)); //head_dim
{
int A_warp_buffer_load_global_offset = (A_warp_buffer_load_m_id * 32);
int A_warp_buffer_load_lds_offset = (A_warp_buffer_load_m_id * 32) + (A_warp_buffer_load_k_id * 4 * BLOCK_M);
if(Is_store_A){
A_warp_buffer_load_lds_offset = (A_warp_buffer_load_m_id * 32) + (A_warp_buffer_load_k_id * (4 * BLOCK_M + 2));
}
int A_gsoffset = (A_block_buffer_load_global_offset + A_warp_buffer_load_global_offset)/2 ;
int A_gvoffset;
if constexpr (Is_even_MN){
A_gvoffset = ((A_lane_m_idx * 2 + (A_lane_k_idx + A_warp_buffer_load_k_id*4)* seqlen_A_stride))/2 ;
} else {
A_gvoffset = ((A_lane_m_idx * 2 + min(A_lane_k_idx + A_warp_buffer_load_k_id*4, K-1)* seqlen_A_stride))/2 ;
}
// int gvOffset = (64*8 + lane_id*8)/2;
int A_lds_offset = ((stage_id)*BLOCK_K*BLOCK_M + A_warp_buffer_load_lds_offset)/2;
if(Is_store_A){
A_lds_offset = ((stage_id)*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + A_warp_buffer_load_lds_offset)/2;
}
builtin_buffer_load_dword_lds(A_lds , A_ptr, A_lds_offset, A_gsoffset, A_gvoffset);
}
}
}
#if 1
// int lds_offset = row * 8 + col * 32;
for(int m_loop = 1; m_loop<(M/BLOCK_M) + 1; m_loop++) {
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id = stage_id ^ 1;
}
}
if(STAGES == 1) {
m_loop--;
}
if((!Is_preload_A)&& m_loop < (M/BLOCK_M)) {
int A_block_buffer_load_global_offset = m_loop*BLOCK_M;
if(Is_store_A){
int A_lds_stage_offset = (stage_id)*(BLOCK_K/32)*(BLOCK_M/32)*32*34;
buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, seqlen_A_stride, BLOCK_M, BLOCK_K, Element, A_ptr, A_lds, A_block_buffer_load_global_offset, A_lds_stage_offset, K, warp_id, lane_id);
} else {
int A_lds_stage_offset = (stage_id)*BLOCK_K*BLOCK_M;
buffer_load_lds_tile(Is_even_MN, WARP_NUM, seqlen_A_stride, BLOCK_M, BLOCK_K, Element, A_ptr, A_lds, A_block_buffer_load_global_offset, A_lds_stage_offset, K, warp_id, lane_id);
}
}
if(!Is_preload_A){
if(STAGES > 1) {
if(m_loop < (M/BLOCK_M)){
// if constexpr(Is_preload_A){
// vmcnt_wait((M/BLOCK_M - m_loop) * (BLOCK_K*BLOCK_M) / (4*32)/WARP_NUM);
// } else {
vmcnt_wait((BLOCK_K*BLOCK_M) / (4*32)/WARP_NUM);
// }
} else {
vmcnt_wait(0);
}
} else {
vmcnt_wait(0);
}
}
if constexpr (STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id --;
} else {
stage_id = stage_id ^ 1;
}
}
if (Is_preload_C && m_loop > 1) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int C_block_buffer_load_global_offset = (m_loop - 2)*BLOCK_M;
// A_ptr buffer load mini size is 32*32, buffer_load_dword mini size is 4*32
int C_lane_m_idx = lane_id % 16;
// int A_lane_k_idx = lane_id / 16;
int C_lane_k_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
for(int warp_loop_temp=0; warp_loop_temp< A_lds_load_num/WARP_NUM; warp_loop_temp++) {
int warp_loop = warp_loop_temp * WARP_NUM + warp_id;
//global->lds, right matrix
int C_warp_buffer_load_k_id = (warp_loop / (BLOCK_K/32)); //seq_len
int C_warp_buffer_load_m_id = (warp_loop % (BLOCK_M/32)); //head_dim
{
int C_warp_buffer_load_global_offset = (C_warp_buffer_load_m_id * 32);
int C_warp_buffer_load_lds_offset = (C_warp_buffer_load_m_id * 32) + (C_warp_buffer_load_k_id * 4 * BLOCK_M);
if(Is_store_A){
C_warp_buffer_load_lds_offset = (C_warp_buffer_load_m_id * 32) + (C_warp_buffer_load_k_id * (4 * BLOCK_M + 2));
}
int C_gsoffset = (C_block_buffer_load_global_offset + C_warp_buffer_load_global_offset)/2 ;
int C_gvoffset;
if constexpr (Is_even_MN){
C_gvoffset = ((C_lane_m_idx * 2 + (C_lane_k_idx + C_warp_buffer_load_k_id*4)* M))/2 ;
} else {
C_gvoffset = ((C_lane_m_idx * 2 + min(C_lane_k_idx + C_warp_buffer_load_k_id*4, K-1)* M))/2 ;
}
// int gvOffset = (64*8 + lane_id*8)/2;
int A_lds_offset = ((m_loop - 2)*BLOCK_K*BLOCK_M + C_warp_buffer_load_lds_offset)/2;
if(Is_store_A){
A_lds_offset = ((m_loop - 2)*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + C_warp_buffer_load_lds_offset)/2;
}
builtin_buffer_load_dword_lds(A_lds , C_ptr, A_lds_offset, C_gsoffset, C_gvoffset);
}
}
}
//lds -> vgpr use ds_read_m; left matrix
int A_lane_head_dim_idx = lane_id % 16;
int A_lane_seq_idx = lane_id / 16;
// __builtin_amdgcn_s_waitcnt(4080 + ((M/BLOCK_M) - m_loop)*(A_lds_load_num/WARP_NUM));
// vmcnt_wait_no_barrier(((M/BLOCK_M) - m_loop)*(A_lds_load_num/WARP_NUM));
vec2_Element<Element> *A_lds_v2fp16 = (vec2_Element<Element> *)(A_lds );
//lds -> vgpr use ds_read_m; right matrix
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(WARP_M/32); head_dim_idx++) {
#pragma unroll
for(int seq_idx=0; seq_idx<(BLOCK_K/32); seq_idx++) {
#pragma unroll
for(int seq_min_tile_idx=0; seq_min_tile_idx<2; seq_min_tile_idx++) { // min k tile
// __builtin_amdgcn_s_waitcnt(4082);
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) //16*32 half need 4 ds_read_b32
{
// int lds_offset = (stage_id*BLOCK_K*BLOCK_M + (seq_idx*32*BLOCK_M) + head_dim_idx*32 * 32 + A_lane_seq_idx/2*4*32 + A_lane_seq_idx%2*32 + (seq_min_tile_idx*32*2) + vec_idx*8*32 + A_lane_head_dim_idx*2)/2;
int lds_offset = stage_id * BLOCK_K * BLOCK_M / 2 + seq_idx * BLOCK_M * 16 + head_dim_idx * 512 + A_lane_seq_idx/2 * 64 + A_lane_seq_idx % 2 * 16 + seq_min_tile_idx * 32 + vec_idx * 128 + A_lane_head_dim_idx;
if constexpr(Is_preload_A || Is_store_A){
// lds_offset = (stage_id*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + (seq_idx*34*BLOCK_M) + head_dim_idx*32 * 34 + A_lane_seq_idx/2*(4*32 + 2) + A_lane_seq_idx%2*32 + (seq_min_tile_idx*32*2) + vec_idx*(8*32+4) + A_lane_head_dim_idx*2)/2;
// lds_offset += (stage_id*(BLOCK_K/32)*(BLOCK_M/32)*32*2 + 2*seq_idx*BLOCK_M + head_dim_idx * 32 * 2 + A_lane_seq_idx/2*2 + vec_idx*4)/2;
lds_offset += stage_id * BLOCK_K * BLOCK_M / 32 + seq_idx * BLOCK_M + head_dim_idx * 32 + A_lane_seq_idx / 2 + vec_idx * 2;
}
inline_ds_read_b32_wait(A_lds_v2fp16, lds_offset, A_reg[(head_dim_idx*(BLOCK_K/32) + seq_idx)*2 + seq_min_tile_idx][vec_idx/2].f16x2[vec_idx%2]);
}
// #pragma unroll
// for(int vec_idx=0; vec_idx<2; vec_idx++) //16*32 half need 4 ds_read_b32
// {
// int lds_offset = (stage_id*BLOCK_K*BLOCK_M + (seq_idx*32*BLOCK_M) + head_dim_idx*32 * 32 + A_lane_seq_idx/2*4*32 + A_lane_seq_idx%2*32 + (seq_min_tile_idx*32*2) + vec_idx*16*32 + A_lane_head_dim_idx*2)/2;
// if constexpr(Is_preload_A || Is_store_A){
// lds_offset = (stage_id*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + (seq_idx*34*BLOCK_M) + head_dim_idx*32 * 34 + A_lane_seq_idx/2*(4*32 + 2) + A_lane_seq_idx%2*32 + (seq_min_tile_idx*32*2) + vec_idx*(16*32+8) + A_lane_head_dim_idx*2)/2;
// }
// inline_ds_read2_b32_no_wait(A_lds_v2fp16, lds_offset, A_reg[(head_dim_idx*(BLOCK_K/32) + seq_idx)*2 + seq_min_tile_idx][vec_idx].f32, 4*32);
// }
}
}
}
// asm volatile("s_waitcnt lgkmcnt(0)");
// __builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 1){
m_loop++;
}
asm volatile("s_setprio 1");
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(BLOCK_K/32); k_idx++) { //BLOCK_K mini size is 32
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
C_reg[(m_loop-1) * ((WARP_M/32)*(WARP_N/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 = flash::mmac<half_t, ElementAccum>(
vec4_Element<half_t>{
UpCast<Element, half_t, true>(A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 0][min_tile_k].f16x2[0][min_tile_m]),
UpCast<Element, half_t, true>(A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 1][min_tile_k].f16x2[0][min_tile_m]),
UpCast<Element, half_t, true>(A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 0][min_tile_k].f16x2[1][min_tile_m]),
UpCast<Element, half_t, true>(A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 1][min_tile_k].f16x2[1][min_tile_m])},
vec4_Element<half_t>{
UpCast<Element, half_t, true>(B_reg[(k_idx)*(WARP_N/32) + n_idx][0*2 + min_tile_n].f16x2[min_tile_k][0]),
UpCast<Element, half_t, true>(B_reg[(k_idx)*(WARP_N/32) + n_idx][1*2 + min_tile_n].f16x2[min_tile_k][0]),
UpCast<Element, half_t, true>(B_reg[(k_idx)*(WARP_N/32) + n_idx][0*2 + min_tile_n].f16x2[min_tile_k][1]),
UpCast<Element, half_t, true>(B_reg[(k_idx)*(WARP_N/32) + n_idx][1*2 + min_tile_n].f16x2[min_tile_k][1])},
C_reg[(m_loop-1) * ((WARP_M/32)*(WARP_N/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
} else {
C_reg[(m_loop-1) * ((WARP_M/32)*(WARP_N/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
vec4_Element<Element>{
A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 0][min_tile_k].f16x2[0][min_tile_m],
A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 1][min_tile_k].f16x2[0][min_tile_m],
A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 0][min_tile_k].f16x2[1][min_tile_m],
A_reg[(m_idx* (BLOCK_K/32) + k_idx)*2 + 1][min_tile_k].f16x2[1][min_tile_m]},
vec4_Element<Element>{
B_reg[(k_idx)*(WARP_N/32) + n_idx][0*2 + min_tile_n].f16x2[min_tile_k][0],
B_reg[(k_idx)*(WARP_N/32) + n_idx][1*2 + min_tile_n].f16x2[min_tile_k][0],
B_reg[(k_idx)*(WARP_N/32) + n_idx][0*2 + min_tile_n].f16x2[min_tile_k][1],
B_reg[(k_idx)*(WARP_N/32) + n_idx][1*2 + min_tile_n].f16x2[min_tile_k][1]},
C_reg[(m_loop-1) * ((WARP_M/32)*(WARP_N/32)) + (n_idx*(WARP_M/32) + m_idx)][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
asm volatile("s_setprio 0");
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id ^=1;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
if constexpr (Is_preload_C) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int C_block_buffer_load_global_offset = 3*BLOCK_M;
// A_ptr buffer load mini size is 32*32, buffer_load_dword mini size is 4*32
int C_lane_m_idx = lane_id % 16;
// int A_lane_k_idx = lane_id / 16;
int C_lane_k_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
const int C_lds_load_num = (BLOCK_M*BLOCK_K) / (4*32);
for(int warp_loop_temp=0; warp_loop_temp< C_lds_load_num/WARP_NUM; warp_loop_temp++) {
int warp_loop = warp_loop_temp * WARP_NUM + warp_id;
//global->lds, right matrix
int C_warp_buffer_load_k_id = (warp_loop / (BLOCK_K/32)); //seq_len
int C_warp_buffer_load_m_id = (warp_loop % (BLOCK_M/32)); //head_dim
{
int C_warp_buffer_load_global_offset = (C_warp_buffer_load_m_id * 32);
int C_warp_buffer_load_lds_offset = (C_warp_buffer_load_m_id * 32) + (C_warp_buffer_load_k_id * 4 * BLOCK_M);
if(Is_store_A){
C_warp_buffer_load_lds_offset = (C_warp_buffer_load_m_id * 32) + (C_warp_buffer_load_k_id * (4 * BLOCK_M + 2));
}
int C_gsoffset = (C_block_buffer_load_global_offset + C_warp_buffer_load_global_offset)/2 ;
int C_gvoffset;
if constexpr (Is_even_MN){
C_gvoffset = ((C_lane_m_idx * 2 + (C_lane_k_idx + C_warp_buffer_load_k_id*4)* M))/2 ;
} else {
C_gvoffset = ((C_lane_m_idx * 2 + min(C_lane_k_idx + C_warp_buffer_load_k_id*4, K-1)* M))/2 ;
}
// int gvOffset = (64*8 + lane_id*8)/2;
int A_lds_offset = (3*BLOCK_K*BLOCK_M + C_warp_buffer_load_lds_offset)/2;
if(Is_store_A){
A_lds_offset = (3*(BLOCK_K/32)*(BLOCK_M/32)*32*34 + C_warp_buffer_load_lds_offset)/2;
}
builtin_buffer_load_dword_lds(A_lds , C_ptr, A_lds_offset, C_gsoffset, C_gvoffset);
}
}
}
#endif
#endif
}
// K BLOCK_K BLOCK_N BLOCK_M BLOCK_K WARP_N
template<bool Is_preload_A, bool Is_store_A, bool Is_even_MN, int M/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gpu_gemm_B_in_reg_gfx938(
vec4_uint A_ptr,
vec4_uint C_ptr,
Element* A_lds,
union_vec4_f16x2<Element> B_reg[(WARP_M/32)*(BLOCK_K/32)*2],
union_vec4_fp32 C_reg[(M/BLOCK_M)*(WARP_M/32)*(WARP_N/32)][4],
int N/*seq_kv*/,
int K/*seq_q*/,
int warp_id,
int seqlen_A_stride) {
#if 1
const int WARP_NUM = (BLOCK_M*BLOCK_N)/(WARP_M*WARP_N);
const int A_lds_load_num = (BLOCK_M*BLOCK_K) / (4*32);
static_assert(BLOCK_K>=32, "Error: gpu_gemm_B_in_reg gemm BLOCK_K must be equal or greater than 32");
static_assert(BLOCK_N>=WARP_N, "Error: gpu_gemm_B_in_reg gemm BLOCK_N must be equal or greater than WARP_N");
static_assert(BLOCK_M==WARP_M, "Error: gpu_gemm_B_in_reg gemm BLOCK_M must be equal to WARP_M");
union_vec4_f16x2<Element> A_reg[((WARP_M*BLOCK_K)/(32*32))*2];
//c mini tile is 32*32
// vec4_fp32 o[(WARP_M/32)*(WARP_N/32)][4]={0};
// __shared__ Element A_lds[STAGES*BLOCK_N * BLOCK_K];
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
int stage_id = 0;
if(STAGES > 1 && (!Is_preload_A)) {
int m_loop = 0;
int A_block_buffer_load_global_offset = m_loop * BLOCK_M;
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
prefetch_to_lds_gfx938<false, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
#if 1
// int lds_offset = row * 8 + col * 32;
for(int m_loop = 1; m_loop<(M/BLOCK_M) + 1; m_loop++) {
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id = stage_id ^ 1;
}
}
if(STAGES == 1) {
m_loop--;
}
if((!Is_preload_A)&& m_loop < (M/BLOCK_M)) {
int A_block_buffer_load_global_offset = m_loop*BLOCK_M;
int A_lds_stage_offset = (stage_id)*BLOCK_K*BLOCK_M;
prefetch_to_lds_gfx938<false, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
//BM = 32, BK = 32
if(warp_id == 0) {
if(!Is_preload_A){
if(STAGES > 1 && m_loop < (M/BLOCK_M)) {
vmcnt_wait(1);
} else {
vmcnt_wait(0);
}
}
}
if constexpr (STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id --;
} else {
stage_id = stage_id ^ 1;
}
}
//lds -> vgpr use ds_read_m; left matrix
if(!Is_preload_A) {
int A_lds_stage_offset = stage_id * BLOCK_K * BLOCK_M;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg[0].f16, A_reg[1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
} else {
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
}
} else {
// gfx938 m_ab = 0的gemm想要复用m_ab = 1的LDS数据
int A_lane_head_dim_idx = lane_id % 16;
int A_lane_seq_idx = lane_id / 16;
vec2_Element<Element> *A_lds_v2fp16 = (vec2_Element<Element> *)(A_lds);
for(int min_tile_k = 0; min_tile_k < 2; min_tile_k++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx++) {
//dword为单位
int lds_offset = stage_id * BLOCK_K * BLOCK_M / 2 + A_lane_seq_idx * 4 * 16 + vec_idx * 16 + min_tile_k * 16 * 16;
lds_offset += (A_lane_head_dim_idx + vec_idx / 2 * 4 + (A_lane_seq_idx % 2) * 8) % 16;
// int lds_offset = stage_id * BLOCK_K * BLOCK_M / 2 + A_lane_seq_idx/2 * 64 + A_lane_seq_idx % 2 * 16 + min_tile_k * 32 + vec_idx * 128 + A_lane_head_dim_idx;
inline_ds_read_b32_wait(A_lds_v2fp16, lds_offset, A_reg[min_tile_k].f16x2[vec_idx]);
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 1){
m_loop++;
}
// asm volatile("s_setprio 1");
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(BLOCK_K/32); k_idx++) { //BLOCK_K mini size is 32
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
} else {
//A采用ds_read后对应的mmac
C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32 = mmac<Element, ElementAccum>(
//BN = 32, BK = 32
// vec4_Element<Element>{A_reg[min_tile_k].f16[0*2 + min_tile_m], A_reg[min_tile_k].f16[1*2 + min_tile_m], A_reg[min_tile_k].f16[2*2 + min_tile_m], A_reg[min_tile_k].f16[3*2 + min_tile_m]},
vec4_Element<Element>{A_reg[min_tile_k].f16x2[0][min_tile_m], A_reg[min_tile_k].f16x2[1][min_tile_m], A_reg[min_tile_k].f16x2[2][min_tile_m], A_reg[min_tile_k].f16x2[3][min_tile_m]},
B_reg[min_tile_k].f16x4[min_tile_n],
C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
// asm volatile("s_setprio 0");
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id ^=1;
__builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
#endif
#endif
}
// K BLOCK_K BLOCK_N BLOCK_M BLOCK_K WARP_N
template<bool Is_preload_A, bool Is_store_A, bool Is_even_MN, int M/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gpu_gemm_B_in_reg_gfx946(
vec4_uint A_ptr,
vec4_uint C_ptr,
Element* A_lds,
union_vec4_f16x2<Element> B_reg[(WARP_M/32)*(BLOCK_K/32)*2],
union_vec4_fp32 C_reg[(M/BLOCK_M)*(WARP_M/32)*(WARP_N/32)][4],
int N/*seq_kv*/,
int K/*seq_q*/,
int warp_id,
int seqlen_A_stride) {
#if 1
const int WARP_NUM = (BLOCK_M*BLOCK_N)/(WARP_M*WARP_N);
const int A_lds_load_num = (BLOCK_M*BLOCK_K) / (4*32);
static_assert(BLOCK_K>=32, "Error: gpu_gemm_B_in_reg gemm BLOCK_K must be equal or greater than 32");
static_assert(BLOCK_N>=WARP_N, "Error: gpu_gemm_B_in_reg gemm BLOCK_N must be equal or greater than WARP_N");
static_assert(BLOCK_M==WARP_M, "Error: gpu_gemm_B_in_reg gemm BLOCK_M must be equal to WARP_M");
union_vec4_f16x2<Element> A_reg[((WARP_M*BLOCK_K)/(32*32))*2];
//c mini tile is 32*32
// vec4_fp32 o[(WARP_M/32)*(WARP_N/32)][4]={0};
// __shared__ Element A_lds[STAGES*BLOCK_N * BLOCK_K];
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
int stage_id = 0;
if(STAGES > 1 && (!Is_preload_A)) {
int m_loop = 0;
int A_block_buffer_load_global_offset = m_loop * BLOCK_M;
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
prefetch_to_lds_gfx938<false, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
#if 1
// int lds_offset = row * 8 + col * 32;
for(int m_loop = 1; m_loop<(M/BLOCK_M) + 1; m_loop++) {
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id = stage_id ^ 1;
}
}
if(STAGES == 1) {
m_loop--;
}
if((!Is_preload_A)&& m_loop < (M/BLOCK_M)) {
int A_block_buffer_load_global_offset = m_loop*BLOCK_M;
int A_lds_stage_offset = (stage_id)*BLOCK_K*BLOCK_M;
prefetch_to_lds_gfx938<false, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
//BM = 32, BK = 32
if(warp_id == 0) {
if(!Is_preload_A){
if(STAGES > 1 && m_loop < (M/BLOCK_M)) {
vmcnt_wait(1);
} else {
vmcnt_wait(0);
}
}
}
if constexpr (STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id --;
} else {
stage_id = stage_id ^ 1;
}
}
//lds -> vgpr use ds_read_m; left matrix
//由于ds_read方式发生了改变,mmac结果矩阵layout变化,存储的时候,offset要进行修改
{
int A_lds_stage_offset = stage_id * BLOCK_K * BLOCK_M;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg[0].f16, A_reg[1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
} else {
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 1){
m_loop++;
}
asm volatile("s_setprio 1");
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(BLOCK_K/32); k_idx++) { //BLOCK_K mini size is 32
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
} else {
//A采用ds_read后对应的mmac
C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
//BN = 32, BK = 32
// vec4_Element<Element>{A_reg[min_tile_k].f16[0*2 + min_tile_m], A_reg[min_tile_k].f16[1*2 + min_tile_m], A_reg[min_tile_k].f16[2*2 + min_tile_m], A_reg[min_tile_k].f16[3*2 + min_tile_m]},
B_reg[min_tile_k].f16x4[min_tile_n],
A_reg[min_tile_k].f16x4[min_tile_m],
C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
// //test only
// for(int min_tile_n = 0; min_tile_n < 2; ++ min_tile_n) {
// for(int min_tile_m = 0; min_tile_m < 2; ++ min_tile_m) {
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[0] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][0]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[1] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][1]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[2] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][2]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[3] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][3]);
// }
// }
asm volatile("s_setprio 0");
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id ^=1;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
#endif
#endif
}
\ No newline at end of file
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "intrinsic.h"
#include "numeric_types.h"
#include "intrinsic_mls_ds.h"
#include "prefetch.h"
// 无预取:prefetch_level = 0; 预取到LDS:prefetch_level = 1; 预取到寄存器:prefetch_level = 2;
template<bool Is_store_B, bool Is_preload_C, bool Is_even_MN, int A_prefetch_level, int B_prefetch_level, int K, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gemm_tt_kq(vec4_uint A_ptr, vec4_uint B_ptr, Element* A_lds, Element* B_lds, int max_m_len_offset, int max_n_len_offset, union_vec2_f16x2<Element> A_reg[(K/BLOCK_K)*((WARP_M*BLOCK_K)/(32*32))*2/((A_prefetch_level == 3)? 1 : 2)][2], union_vec2_f16x2<Element> B_reg[STAGES*((WARP_N*BLOCK_K)/(32*32))*2][2], union_vec4_fp32 C_reg[(WARP_M/32)*(BLOCK_N/32)][4], int warp_id, int seqlen_A_stride, int seqlen_B_stride) {
const int WARP_NUM = BLOCK_M/WARP_M;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
#if 1
for(int n_loop = 0 ; n_loop < BLOCK_N/WARP_N; n_loop++)
{
int stage_id = 0;
int stage_id_reg = 0;
{ int k_loop = 0;
if(STAGES > 1) {
if(A_prefetch_level == 0) {
int A_block_buffer_load_global_offset = k_loop * BLOCK_K;
int A_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*34);
buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, seqlen_A_stride, BLOCK_M, BLOCK_K, Element, A_ptr, A_lds, A_block_buffer_load_global_offset, A_lds_stage_offset, K, warp_id, lane_id);
}
if(B_prefetch_level == 0) {
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
int B_lds_stage_offset = stage_id * (WARP_N/32) * (BLOCK_K/32)*(32*34);
if constexpr (Is_store_B){
B_lds_stage_offset += n_loop * (K/32) * (WARP_N/32)*(32*34);
}
buffer_load_lds_tile_pad_1(Is_even_MN, WARP_NUM, seqlen_B_stride, WARP_N, BLOCK_K, Element, B_ptr, B_lds, B_block_buffer_load_global_offset, B_lds_stage_offset, K, warp_id, lane_id);
}
}
}
// int lds_offset = row * 8 + col * 32;
for(int k_loop = 1; k_loop<(K/BLOCK_K) + 1; k_loop++) {
if constexpr (STAGES > 1) {
if constexpr (Is_store_B || B_prefetch_level == 1){
stage_id++;
} else {
stage_id ^= 1;
}
}
if(STAGES == 1) {
k_loop--;
}
if(k_loop < (K/BLOCK_K)){
if(A_prefetch_level == 0 || (A_prefetch_level == 1 && k_loop >= (K/BLOCK_K)/2)) {
int A_block_buffer_load_global_offset = k_loop * BLOCK_K;
int A_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*34);
buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, seqlen_A_stride, BLOCK_M, BLOCK_K, Element, A_ptr, A_lds, A_block_buffer_load_global_offset, A_lds_stage_offset, K, warp_id, lane_id);
}
if(B_prefetch_level == 0) {
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
int B_lds_stage_offset = stage_id * (WARP_N/32) * (BLOCK_K/32)*(32*34);
if constexpr (Is_store_B || B_prefetch_level == 1){
B_lds_stage_offset += n_loop * (K/32) * (WARP_N/32)*(32*34);
}
buffer_load_lds_tile_pad_1(Is_even_MN, WARP_NUM, seqlen_B_stride, WARP_N, BLOCK_K, Element, B_ptr, B_lds, B_block_buffer_load_global_offset, B_lds_stage_offset, K, warp_id, lane_id);
}
}
else if (B_prefetch_level==0){
vmcnt_wait(0);
}
int precompute_B_lds_offset[2*2];
//lds -> vgpr use ds_read_m; right matrix
int k_warp_n_id = (warp_id & (WARP_N/WARP_N - 1));
int k_lds_stage_offset = STAGES == 1 ? 0 : ( (Is_store_B || B_prefetch_level == 1) ? (stage_id - 1) * (WARP_N/32) * (BLOCK_K/32)*(32*17) : (stage_id ^ 1) * (WARP_N/32) * (BLOCK_K/32)*(32*17));
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(B_lds);
//a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(BLOCK_K/32); head_dim_idx++) { //32 half in col direction
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int vec_idx = 0; vec_idx < 2; vec_idx ++) {
int lds_offset = k_lds_stage_offset + head_dim_idx*(WARP_N*17) + (k_warp_n_id*(WARP_N/32) + n_idx)*(32*17) + vec_idx * 4 + min_tile_n*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) & 1)*8 + (lane_id/32);
precompute_B_lds_offset[min_tile_n * 2 + vec_idx] = lds_offset;
if constexpr (Is_store_B || B_prefetch_level == 1){
precompute_B_lds_offset[min_tile_n * 2 + vec_idx] += n_loop * (WARP_N/32) * (K/32)*(32*17);
}
}
}
}
}
if(STAGES > 1) {
if constexpr(B_prefetch_level==1){
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
vmcnt_wait(0);
} else {
vmcnt_wait(((BLOCK_N/WARP_N * K/BLOCK_K)*(Is_preload_C ? 2 : 1) - (n_loop * (K/BLOCK_K) + k_loop)) * (WARP_N*BLOCK_K) / (4*32)/WARP_NUM);
}
} else {
if(k_loop < (K/BLOCK_K)){
if(A_prefetch_level == 0 && B_prefetch_level != 0) {
buffer_load_lds_dwordx1_wait<(BLOCK_M * BLOCK_K) / (4*32)/WARP_NUM>();
} else if(A_prefetch_level != 0 && B_prefetch_level == 0) {
buffer_load_lds_dwordx1_wait<(WARP_N*BLOCK_K) / (4*32)/WARP_NUM>();
} else if(A_prefetch_level == 0 && B_prefetch_level == 0) {
buffer_load_lds_dwordx1_wait<(BLOCK_M * BLOCK_K) / (4*32)/WARP_NUM + (WARP_N*BLOCK_K) / (4*32)/WARP_NUM>();
}
}
}
} else {
vmcnt_wait(0);
}
if constexpr (STAGES > 1) {
if constexpr (Is_store_B || B_prefetch_level == 1){
stage_id--;
} else {
stage_id ^= 1;
}
}
union_vec2_f16x2<Element> A_reg_tmp[2][2];
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 )) {
//lds -> vgpr use ds_read_m; left matrix
int A_warp_m_id = (warp_id & ((BLOCK_M/WARP_M) - 1));
int A_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*17);
vec2_Element<Element> *A_lds_v2fp16 = (vec2_Element<Element> *)(A_lds);
asm volatile("s_setprio 1");
// #pragma unroll
// for(int head_dim_idx=0; head_dim_idx<(BLOCK_K/32); head_dim_idx++) { //32 half in col direction
// #pragma unroll
// for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
// //a warp load min size is (row, col) = (32,16) float
// #pragma unroll
// for(int i=0; i<2; i++) { //sequence direction
// #pragma unroll
// for(int j=0; j<2; j++) { //head dim direction
// int lds_offset = A_lds_stage_offset + head_dim_idx*BLOCK_M*17 + (warp_id*(WARP_M/32) + m_idx)*(32*17) + j*4 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
// inline_ds_read2_b32_no_wait(A_lds_v2fp16, lds_offset, A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + i][j].f32, 2);
// }
// // #pragma unroll
// // for(int j=0; j<4; j++) { //head dim direction
// // int lds_offset = A_lds_stage_offset + head_dim_idx*BLOCK_M*17 + (warp_id*(WARP_M/32) + m_idx)*(32*17) + j*2 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
// // inline_ds_read_b32_wait(A_lds_v2fp16, lds_offset, A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + i][j/2].f16x2[j%2]);
// // }
// }
// }
// }
ds_read_tile_pad<WARP_M, BLOCK_K, WARP_NUM, Element>(A_lds_v2fp16, A_lds_stage_offset, A_reg_tmp, 0, warp_id, lane_id);
asm volatile("s_setprio 0");
}
// int k_warp_n_id = (warp_id & (WARP_N/WARP_N - 1));
// int k_lds_stage_offset = STAGES == 1 ? 0 : (stage_id ) * (WARP_N/32) * (BLOCK_K/32)*(32*17);
// if constexpr (Is_store_B || B_prefetch_level == 1){
// k_lds_stage_offset += n_loop * (WARP_N/32) * (K/32)*(32*17);
// }
// vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(B_lds);
ds_read2_tile_pad_no_wait(WARP_M, BLOCK_K, WARP_NUM, Element, k_lds_v2fp16, precompute_B_lds_offset, B_reg, stage_id_reg, k_warp_n_id, lane_id);
// ds_read2_tile_pad_no_wait<WARP_M, BLOCK_K, WARP_NUM, Element>(k_lds_v2fp16, k_lds_stage_offset, B_reg, stage_id_reg, k_warp_n_id, lane_id);
// ds_read_tile_pad<WARP_M, BLOCK_K, WARP_NUM, Element>(k_lds_v2fp16, k_lds_stage_offset, B_reg, stage_id_reg, k_warp_n_id, lane_id);
if constexpr (STAGES == 1){
k_loop++;
}
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
asm volatile("s_setprio 1");
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
const int lgkmcnt = 2 - min_tile_n*2;
lgkmcnt_wait(lgkmcnt);
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
for(int head_dim_idx=0; head_dim_idx< (BLOCK_K/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 + 1 )){
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = flash::mmac<half_t, ElementAccum>(
vec4_Element<half_t>{
UpCast<Element, half_t, true>(A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][0]),
UpCast<Element, half_t, true>(A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][1]),
UpCast<Element, half_t, true>(A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][0]),
UpCast<Element, half_t, true>(A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][1])},
vec4_Element<half_t>{
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][0]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][1]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][0]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][1])},
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
} else {
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = flash::mmac<half_t, ElementAccum>(
vec4_Element<half_t>{
UpCast<Element, half_t, true>(A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][0]),
UpCast<Element, half_t, true>(A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][1]),
UpCast<Element, half_t, true>(A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][0]),
UpCast<Element, half_t, true>(A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][1])},
vec4_Element<half_t>{
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][0]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][1]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][0]),
UpCast<Element, half_t, true>(B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][1])},
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
}
} else {
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 + 1 )){
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
vec4_Element<Element>{
A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][0],
A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][1],
A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][0],
A_reg_tmp[(head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][1]},
vec4_Element<Element>{
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][0],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][1],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][0],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][1]},
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
} else {
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = flash::mmac<Element, ElementAccum>(
vec4_Element<Element>{
A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][0],
A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[0][1],
A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][0],
A_reg[(k_loop-1)*((WARP_M*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_M/32) + m_idx)*2 + min_tile_m][min_tile_k].f16x2[1][1]},
vec4_Element<Element>{
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][0],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[0][1],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][0],
B_reg[(stage_id_reg *((WARP_N*BLOCK_K)/(32*32))*2 + (head_dim_idx*(WARP_N/32) + n_idx))*2 + min_tile_n][min_tile_k].f16x2[1][1]},
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
}
asm volatile("s_setprio 0");
if constexpr (STAGES > 1){
if constexpr (!Is_store_B && B_prefetch_level !=1) {
stage_id ^= 1;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
} else{
stage_id++;
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
}
#endif
}
// 无预取:prefetch_level = 0; 预取到LDS:prefetch_level = 1; 预取到寄存器:prefetch_level = 2;
template<bool Is_store_B, bool Is_preload_C, bool Is_even_MN, int A_prefetch_level, int B_prefetch_level, int K, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gemm_tt_kq_gfx938(
vec4_uint A_ptr,
vec4_uint B_ptr,
Element* A_lds,
Element* B_lds,
int max_m_len_offset,
int max_n_len_offset,
union_vec4_f16x2<Element> A_reg[(K/BLOCK_K)*((WARP_M*BLOCK_K)/(32*32))*2/((A_prefetch_level == 3)? 1 : 2)],
union_vec4_f16x2<Element> B_reg[STAGES*((WARP_N*BLOCK_K)/(32*32))*2],
union_vec4_fp32 C_reg[(WARP_M/32)*(BLOCK_N/32)][4],
int warp_id,
int seqlen_A_stride,
int seqlen_B_stride) {
const int ELEMENT_BYTES = sizeof(Element);
const int WARP_NUM = BLOCK_M/WARP_M;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
#if 1
for(int n_loop = 0 ; n_loop < BLOCK_N/WARP_N; n_loop++)
{
int stage_id = 0;
int stage_id_reg = 0;
{ int k_loop = 0;
if(STAGES > 1) {
if(A_prefetch_level == 0) {
int A_block_buffer_load_global_offset = k_loop * BLOCK_K;
int A_lds_stage_offset = stage_id * BLOCK_M* BLOCK_K;
prefetch_to_lds_gfx938<true, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
if(B_prefetch_level == 0) {
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
if constexpr (Is_store_B){
B_lds_stage_offset += n_loop * K * WARP_N;
}
prefetch_to_lds_gfx938<true, WARP_N, BLOCK_K, Element, ElementAccum, Is_even_MN>(B_ptr, B_block_buffer_load_global_offset, B_lds + B_lds_stage_offset, seqlen_B_stride, warp_id);
}
}
}
// int lds_offset = row * 8 + col * 32;
for(int k_loop = 1; k_loop<(K/BLOCK_K) + 1; k_loop++) {
if constexpr (STAGES > 1) {
if constexpr (Is_store_B || B_prefetch_level == 1){
stage_id++;
} else {
stage_id ^= 1;
}
}
if(STAGES == 1) {
k_loop--;
}
if(k_loop < (K/BLOCK_K)){
if(A_prefetch_level == 0 || (A_prefetch_level == 1 && k_loop >= (K/BLOCK_K)/2)) {
int A_block_buffer_load_global_offset = k_loop * BLOCK_K;
int A_lds_stage_offset = stage_id * BLOCK_M* BLOCK_K;
prefetch_to_lds_gfx938<true, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
if(B_prefetch_level == 0) {
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
if constexpr (Is_store_B){
B_lds_stage_offset += n_loop * K * WARP_N;
}
prefetch_to_lds_gfx938<true, WARP_N, BLOCK_K, Element, ElementAccum, Is_even_MN>(B_ptr, B_block_buffer_load_global_offset, B_lds + B_lds_stage_offset, seqlen_B_stride, warp_id);
}
}
else if (B_prefetch_level==0){
vmcnt_wait_nosync(0);
}
//MLS可以一条指令读32*32,而buffer_load_dword一条指令读2*64的数据,所以waitcnt需要进行修改
//且MLS是一个warp读32*32,4个warp读128*32,需要判断warp_id来wait
if(STAGES > 1) {
if constexpr(B_prefetch_level==1){
if((k_loop - 1) % WARP_NUM == warp_id)
{
if(Is_preload_C) {
vmcnt_wait_nosync(1);
} else {
vmcnt_wait_nosync(0);
}
}
} else {
if(k_loop < (K/BLOCK_K)){
if(A_prefetch_level == 0 && B_prefetch_level != 0) {
//BM = 128
vmcnt_wait_nosync((BLOCK_M * BLOCK_K) / (32*32)/WARP_NUM);
} else if(A_prefetch_level != 0 && B_prefetch_level == 0) {
//BN = 32
if(warp_id == 0) {
vmcnt_wait_nosync(1);
}
} else if(A_prefetch_level == 0 && B_prefetch_level == 0) {
//BM= 128 & BN = 32
if(warp_id == 0) {
vmcnt_wait_nosync((BLOCK_M * BLOCK_K) / (32*32)/WARP_NUM + 1);
} else {
vmcnt_wait_nosync(1);
}
}
}
}
} else {
vmcnt_wait_nosync(0);
}
__syncthreads();
if constexpr (STAGES > 1) {
if constexpr (Is_store_B || B_prefetch_level == 1){
stage_id--;
} else {
stage_id ^= 1;
}
}
union_vec4_f16x2<Element> A_reg_tmp[2];
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 )) {
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg_tmp[0].f16, A_reg_tmp[1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
} else {
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
}
}
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
DS_READ_MATRIX_32X32_B16(ds_offset_cast(B_lds + B_lds_stage_offset), B_reg[0].f16, B_reg[1].f16, true);
if constexpr (STAGES == 1){
k_loop++;
}
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
// asm volatile("s_setprio 1");
lgkmcnt_wait(0);
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
for(int head_dim_idx=0; head_dim_idx< (BLOCK_K/32); head_dim_idx++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
} else {
if (A_prefetch_level == 0 || (A_prefetch_level != 3 && k_loop >= (K/BLOCK_K)/2 + 1 )){
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
A_reg_tmp[min_tile_m].f16x4[min_tile_k],
B_reg[min_tile_n].f16x4[min_tile_k],
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
} else {
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx) * (WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
A_reg[(k_loop - 1) * 2 + min_tile_m].f16x4[min_tile_k],
B_reg[min_tile_n].f16x4[min_tile_k],
C_reg[n_loop*(WARP_N/32*WARP_M/32) + (n_idx)*(WARP_M/32) + m_idx][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
}
// asm volatile("s_setprio 0");
if constexpr (STAGES > 1){
if constexpr (!Is_store_B && B_prefetch_level !=1) {
stage_id ^= 1;
__builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
} else{
stage_id++;
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
}
#endif
}
\ No newline at end of file
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "utils.h"
#include "static_switch.h"
#include "numeric_types.h"
#include "intrinsic_mls_ds.h"
template<int K, int BLOCK_M, int BLOCK_K, int WARP_M, typename Element, typename ElementAccum, bool Is_even_MN>
inline __device__ void prefetch_to_vgpr(
vec4_uint k_ptr,
Element* k_lds,
union_vec2_f16x2<Element> k_reg[(K/BLOCK_K)*((WARP_M*BLOCK_K)/(32*32))*2][2],
int max_seq_k_offset,
int row_stride) {
const int WARP_NUM = (BLOCK_M)/(WARP_M);
const int k_lds_load_num = (BLOCK_M * BLOCK_K) / (4*32);
const int K_LOAD_REQUESTS = k_lds_load_num / WARP_NUM;
int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
int k_warp_m_id = (warp_id & ((BLOCK_M/WARP_M) - 1));
int lane_id = threadIdx.x & 63; //lane id, 0-63
int k_lane_m_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
int k_lane_head_dim_idx = lane_id & 15;
// int lds_offset = row * 8 + col * 32;
int stage_id = 0;
// MLS
vec4_uint k_srsrc;
k_srsrc[2] = row_stride; // stride
k_srsrc[3] = 0;
#pragma unroll
for(int k_loop = 0; k_loop<K/BLOCK_K; k_loop++) {
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
//global->lds, left matrix
int q_block_buffer_load_global_offset = k_loop * BLOCK_K ;//+ block_id_m * BLOCK_M * K;
// k_ptr buffer load mini size is 4*32, (BLOCK_M * BLOCK_K) mini size is (32*32)
int k_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*34);
for(int load = 0,warp_loop = warp_id; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7)*2; // padding size in shared memory per buffer load, to avoid bank conflict
int k_warp_buffer_load_m_id = (warp_loop & (BLOCK_M/4 - 1)); //这样子对L1和utlc1有啥影响呢?
// int q_warp_buffer_load_k_id = (warp_loop / (BLOCK_M/4));
int q_warp_buffer_load_lds_offset = k_lds_stage_offset/* + (q_warp_buffer_load_k_id * BLOCK_M * 34)*/ + ((k_warp_buffer_load_m_id >> 3)*(32*34) + (k_warp_buffer_load_m_id & 7)*(4*32));
// int q_warp_buffer_load_global_offset = (q_warp_buffer_load_k_id * 32);
int gvOffset_s = (q_block_buffer_load_global_offset/* + q_warp_buffer_load_global_offset*/) / 2;
int gvOffset_v;
if constexpr (not Is_even_MN) {
gvOffset_v = ((min(k_warp_buffer_load_m_id * 4 + k_lane_m_idx, max_seq_k_offset - 1)) * row_stride) / 2 + k_lane_head_dim_idx;
} else {
gvOffset_v = ((k_warp_buffer_load_m_id * 4 + k_lane_m_idx) * row_stride) / 2 + k_lane_head_dim_idx;
}
int lds_offset = (q_warp_buffer_load_lds_offset + padding) / 2; // + lane_id;
builtin_buffer_load_dword_lds_bypass_glc_slc(k_lds, k_ptr, lds_offset, gvOffset_s, gvOffset_v);
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
// k_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*17);
vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
ds_read_tile_pad(WARP_M, BLOCK_K, WARP_NUM, Element, k_lds_v2fp16, k_lds_stage_offset, k_reg, k_loop, warp_id, lane_id);
}
}
}
//matrix_load单位:32 * 32
//ds_read_matrix单位:32 * 16
//M = 128, N = 128
template<bool trans, int M, int N, typename Element, typename ElementAccum, bool Is_even_MN>
inline __device__ void prefetch_to_vgpr_gfx938(
vec4_uint ptr,
Element* lds,
union_vec4_f16x2<Element> reg[M * N / (64 * 8)],//vec4_fp16x2有8个element,64个线程
int max_column_offset,
int warp_id) {
constexpr int ELEMENT_BYTES = sizeof(Element);
const int stages = 2;
const int WARP_NUM = 4;
int row_stride = ptr[2];
vec4_uint srsrc;
srsrc[2] = row_stride;
srsrc[3] = 0;
//计算LDS地址,每个warp使用一个32*32
int lds_offset = (warp_id * 32 * 32);
size_t lds_load_offset = reinterpret_cast<size_t>(lds) + lds_offset * ELEMENT_BYTES;
int stages_id = 0;
if(stages == 2) {
int m_loop = 0;
int n_loop = 0;
int global_offset = (warp_id * row_stride * 32 + n_loop * 32);
int lds_offset_stage = (lds_offset + stages_id * (WARP_NUM * 32 * 32)) * ELEMENT_BYTES;
if constexpr (!Is_even_MN) {
//对M方向进行边界判断,看需要pad多少0
int nm_filter_max = (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset;
int nm_filter = max(0, (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset);
if(nm_filter_max >= 32) {
global_offset = (0 * row_stride * 32 + n_loop * 32);
nm_filter = max(0, (m_loop * 128 + 0 * 32) - max_column_offset);
}
srsrc[3] = nm_filter << 8; // set only once
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
if(trans) {
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset_stage, 0);
} else {
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset_stage, 0);
}
}
for(int m_loop = 0; m_loop < M / 128; ++m_loop) {
for(int n_loop = stages - 1; n_loop < N / 32 + stages - 1; ++n_loop) {
if(stages == 2) {
stages_id ^= 1;
}
//更新global地址
int global_offset = (warp_id * row_stride * 32 + n_loop * 32);
int lds_offset_stage = (lds_offset + stages_id * (WARP_NUM * 32 * 32)) * ELEMENT_BYTES;
// size_t lds_load_offset_stage = reinterpret_cast<size_t>(lds) + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) * ELEMENT_BYTES + lds_offset * ELEMENT_BYTES;
if constexpr (!Is_even_MN) {
//对M方向进行边界判断,看需要pad多少0
int nm_filter_max = (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset;
int nm_filter = max(0, (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset);
if(nm_filter_max >= 32) {
global_offset = (0 * row_stride * 32 + n_loop * 32);
nm_filter = max(0, (m_loop * 128 + 0 * 32) - max_column_offset);
}
srsrc[3] = nm_filter << 8; // set only once
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
if(n_loop < N / 32) {
if(trans) {
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset_stage, 0);
} else {
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset_stage, 0);
}
}
if(stages == 2 && n_loop < N /32) {
vmcnt_wait_nosync(1);
} else {
vmcnt_wait_nosync(0);
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
if(trans){
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
} else {
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
}
} else {
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
} else {
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
}
}
lgkmcnt_wait(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
}
}
}
//matrix_load单位:32 * 32
//ds_read_matrix单位:32 * 16
//M = 32, N = 128
template<bool trans, int M, int N, typename Element, typename ElementAccum, bool Is_even_MN, int WARP_NUM = 4>
inline __device__ void prefetch_to_lds_gfx938(
vec4_uint ptr,
int global_start_offset,
Element* lds,
int max_column_offset,
int warp_id) {
const int ELEMENT_BYTES = sizeof(Element);
const int LOAD_NUM = M * N / (32 * 32);
int row_stride = ptr[2];
vec4_uint srsrc;
srsrc[2] = row_stride;
srsrc[3] = 0;
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
//直接拉通M * N,看有多少个 32*32 的矩阵需要load
for(int loop = 0; loop < (LOAD_NUM + WARP_NUM - 1) / WARP_NUM; loop++) {
int loop_warp = loop * WARP_NUM + warp_id;
if (loop_warp < LOAD_NUM) {
int m_loop = loop_warp / (N / 32);
int n_loop = loop_warp % (N / 32);
//更新global地址
int global_offset = (global_start_offset + m_loop * row_stride + n_loop * 32) * ELEMENT_BYTES;
if constexpr (!Is_even_MN) {
//对M方向进行边界判断,看需要pad多少0
int nm_filter_max = (m_loop + 1) * 32 - max_column_offset;
int nm_filter = nm_filter_max;
if(nm_filter_max >= 32) {
global_offset = (global_start_offset + 0 * row_stride + n_loop * 32) * ELEMENT_BYTES;
nm_filter = (0 + 1) * 32 - max_column_offset;
}
nm_filter = max(0, nm_filter);
srsrc[3] = nm_filter << 8; // set only once
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset);
//计算LDS地址,每个warp使用一个32*32;下一个loop重复利用
int lds_offset = (loop_warp * 32 * 32) * ELEMENT_BYTES;
int lds_load_offset = reinterpret_cast<size_t>(lds) + lds_offset;
if (trans) {
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset, 0);
} else {
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset, 0);
}
}
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
}
template<bool Is_even_MN, int K/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, typename Element>
__forceinline__ __device__ void prefetch_to_tmp_lds_wait(vec4_uint B_ptr, Element* B_lds, int max_n_len_offset, int warp_id, int row_stride)
{
const int WARP_NUM = BLOCK_M/WARP_M;
int lane_id = threadIdx.x & 63; //lane id, 0-63
for(int n_loop = 0 ; n_loop < BLOCK_N/WARP_N; n_loop++){
for(int k_loop = 0; k_loop < K/BLOCK_K; k_loop++) {
const int lgkmcnt = (BLOCK_N/WARP_N * K/BLOCK_K - 1) - (n_loop * K/BLOCK_K + k_loop);
lgkmcnt_wait(lgkmcnt);
int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
// headdim=256时的LDS用量为 256/32 * 32 * 34 * 2byte= 17 KB,如果同时读Q和dO到LDS,就会超过32KB
// headdim=224时的LDS用量为 224/32 * 32 * 34 * 2byte= 14.875 KB,如果同时读Q和dO到LDS,不会超32KB
int B_lds_stage_offset = k_loop * (WARP_N/32) * (BLOCK_K/32)*(32*34) + n_loop * (K/32) * (WARP_N/32)*(32*34);
buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, row_stride, WARP_N, BLOCK_K, Element, B_ptr, B_lds, B_block_buffer_load_global_offset, B_lds_stage_offset, max_n_len_offset, warp_id, lane_id);
}
}
}
\ No newline at end of file
#pragma once
#include "numeric_types.h"
#include "utils.h"
using namespace flash;
//32*32的tile,结果矩阵根据奇偶分开设计
//mask_type == 0:无mask
//mask_type == 1: mask矩阵右上角
//mask_type == 2: mask矩阵左下角
template <bool Is_even_MN, int mask_type>
inline __device__ void apply_mask_bwd(union_vec4_fp32 tensor[1][4], int M, int N, int M_minus_N, int window_size_left, int window_size_right) {
const int lane_id = threadIdx.x & 63;
const int lane_m_idx = (lane_id & 15);
const int lane_n_idx = (lane_id >> 4);
//无mask,仅进行边界判断
if(!Is_even_MN && mask_type == 0) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = lane_n_idx * 2 + min_tile_n + vec_idx * 8;
if(N_offset > N - 1){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//mask右上角
if (mask_type == 1 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = lane_m_idx * 2 + min_tile_m;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = lane_n_idx * 2 + min_tile_n + vec_idx * 8;
int N_limit = Is_even_MN ? (M_offset + M_minus_N) : min(N - 1, M_offset + M_minus_N);
if(N_offset > N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//mask左下角
if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = lane_m_idx * 2 + min_tile_m;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = lane_n_idx * 2 + min_tile_n + vec_idx * 8;
int N_limit = (M_offset + M_minus_N);
if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//local mask
if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = lane_m_idx * 2 + min_tile_m;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = lane_n_idx * 2 + min_tile_n + vec_idx * 8;
int N_limit_left = (M_offset + M_minus_N - window_size_left);
int N_limit_right = (M_offset + M_minus_N + window_size_right);
if((!Is_even_MN && N_offset > N - 1) || N_offset <= N_limit_left || N_offset >= N_limit_right){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
}
//32*32的tile,结果矩阵根据mmac_4interleave设计
//mask_type == 0:无mask
//mask_type == 1: mask矩阵右上角
//mask_type == 2: mask矩阵左下角
template <bool Is_even_MN, int mask_type>
inline __device__ void apply_mask_bwd_gfx938(union_vec4_fp32 tensor[1][4], int M, int N, int M_minus_N, int window_size_left, int window_size_right) {
const int lane_id = threadIdx.x & 63;
const int lane_m_idx = (lane_id & 15);
const int lane_n_idx = (lane_id >> 4);
//无mask,仅进行边界判断
if(!Is_even_MN && mask_type == 0) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
if(N_offset > N - 1){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//mask右上角
if (mask_type == 1 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = min_tile_m * 16 + lane_m_idx;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
int N_limit = Is_even_MN ? (M_offset + M_minus_N) : min(N - 1, M_offset + M_minus_N);
if(N_offset > N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//mask左下角
if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = min_tile_m * 16 + lane_m_idx;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
int N_limit = (M_offset + M_minus_N);
if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//local mask
if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = min_tile_m * 16 + lane_m_idx;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
int N_limit_left = (M_offset + M_minus_N - window_size_left);
int N_limit_right = (M_offset + M_minus_N + window_size_right);
if((!Is_even_MN && N_offset > N - 1) || N_offset <= N_limit_left || N_offset >= N_limit_right){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
}
template <bool encode_dropout_in_sign_bit=false, typename DataType, int WARP_M, int WARP_N>
inline __device__ void apply_dropout(const DataType tensor[(WARP_M/32)*(WARP_N/32)][4], uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset,
int block_col_start, int block_row_start,
int block_col_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2)
auto encode_dropout = [](bool keep, DataType val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : DataType(0));
};
// static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int n = 0; n < (WARP_N/32); ++n, block_col_start += block_col_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);
#pragma unroll
for (int m = 0; m < (WARP_M/32); ++m, ++rowcol.x) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<DataType, Float16>::value || std::is_same<DataType, BFloat16>::value)) {
// uint16_t rnd_16[16];
// #pragma unroll
// for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
// uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
// #pragma unroll
// for (int j = 0; j < 2; j++) {
// Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// uint32_t mask;
// asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
// tensor_uint32(i) &= mask;
// }
// // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
// }
} else {
//min tile for a warp is 32*32
#pragma unroll
for (int n_idx = 0; n_idx < 2; n_idx++) {
#pragma unroll
for (int m_idx = 0; m_idx < 2; m_idx++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { //mmac min_tile is 16*16, a warp is 64 thread
tensor[(n*(WARP_N/16)*(WARP_M/16) + m*(WARP_M/16)) + n_idx * 2 + m_idx][vec_idx] = encode_dropout(rnd_8[n_idx * 8 + m_idx] <= p_dropout_in_uint8_t, tensor[(n*(WARP_N/16)*(WARP_M/16) + m*(WARP_M/16)) + n_idx * 2 + m_idx][vec_idx]);
}
}
// Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
}
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void thread_reduce_(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
summary[m_idx*2 + min_tile_m] = (OpType==0)? 0 : -INFINITY; //OpType:0 is sum operator, 1 is max operator
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { //mmac min_tile is 16*16, a warp is 64 thread
summary[m_idx*2 + min_tile_m] = op(summary[m_idx*2 + min_tile_m], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + min_tile_m][vec_idx]);
}
}
}
}
}
} else {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
summary_cur[m_idx*2 + min_tile_m] = summary[m_idx*2 + min_tile_m];// op(summary[m_idx*2 + min_tile_m], tensor[m_idx][min_tile_m][0]);
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) { //mmac min_tile is 16*16, a warp is 64 thread
summary_cur[m_idx*2 + min_tile_m] = op(summary_cur[m_idx*2 + min_tile_m], tensor[m_idx + n_idx*(WARP_M/32)][min_tile_n*2 + min_tile_m][vec_idx]);
}
}
}
}
}
}
}
template<typename Operator, typename DataType, int WARP_M>
__device__ inline void quad_allreduce_(DataType *dst, DataType *src, Operator &op) {
#pragma unroll
for (int i = 0; i < (WARP_M/16); i++) {
dst[i] = Allreduce<64>::run(src[i], op);
}
}
template<bool zero_init=true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) {
if(zero_init == true) {
thread_reduce_<true, Operator, OpType, DataType0, DataType1, WARP_M, WARP_N>(tensor, summary, op);
quad_allreduce_<Operator, DataType1, WARP_M>(summary, summary, op);
} else {
thread_reduce_<false, Operator, OpType, DataType0, DataType1, WARP_M, WARP_N>(tensor, summary, op, summary_cur);
quad_allreduce_<Operator, DataType1, WARP_M>(summary_cur, summary_cur, op);
}
}
//zero_init==true, max is current max_score, max_cur=nullptr
//zero_init==true, max is prev max_score, max_cur!=nullptr
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_max(const DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *max , DataType1 *max_cur=nullptr) {
MaxOp<float> max_op;
if(zero_init == true) {
reduce_<true, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N>(tensor, max, max_op);
} else {
reduce_<false, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N>(tensor, max, max_op, max_cur);
}
}
template<bool zero_init=true, typename DataType0, typename DataType1, int WARP_M, int WARP_N>
__device__ inline void reduce_sum(DataType0 tensor[(WARP_M/32)*(WARP_N/32)][4], DataType1 *sum, DataType1 *sum_cur=nullptr){
SumOp<float> sum_op;
if(zero_init == true) {
reduce_<true, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N>(tensor, sum, sum_op);
} else {
reduce_<false, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N>(tensor, sum, sum_op, sum_cur);
}
}
// Apply the exp to all the elements.
template <bool Scale_max=true, int BLOCK_M, int WARP_N, typename DataType0, typename DataType1>
inline __device__ void scale_apply_exp2_bwd(DataType0 tensor[(BLOCK_M/32)*(WARP_N/32)][4], const DataType1 *max, const float scale) {
// #if defined(__gfx936__)
// auto vec2_scale = vec2_fp32{scale, scale};
// #endif
#pragma unroll
for (int mi = 0; mi < (BLOCK_M/32); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) {
const float max_scaled = (max[(mi*2 + min_tile_m)*4 + vec_idx] * (Scale_max ? scale : float(M_LOG2E)));
// #if defined(__gfx936__)
// auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled};
// #endif
#pragma unroll
for (int ni = 0; ni < (WARP_N/32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#if 0//defined(__gfx936__)
auto vec2_tensor = vec2_fp32{tensor[ni + mi*(WARP_N/32)][min_tile_m*2].f32[vec_idx], tensor[ni + mi*(WARP_N/32)][min_tile_m*2 + 1].f32[vec_idx]};
auto vec2_scale = vec2_fp32{scale, scale};
auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled};
auto tensor_tmp =
__builtin_hcu_pk_fma_f32(
vec2_tensor,
vec2_scale,
vec2_max_scaled);
// __builtin_hcu_v_pk_fma_f32(
// vec2_tensor,
// vec2_scale,
// vec2_max_scaled);
tensor[ni + mi*(WARP_N/32)][min_tile_m*2].f32[vec_idx] = __llvm_exp2_f32(tensor_tmp[0]);
tensor[ni + mi*(WARP_N/32)][min_tile_m*2 + 1].f32[vec_idx] = __llvm_exp2_f32(tensor_tmp[1]);
#else
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { //使用__llvm_exp2_f32会产生nan,使用exp2f则没问题
// tensor[ni + mi*(WARP_N/32)][min_tile_n + min_tile_m*2].f32[vec_idx] =exp2f(tensor[ni + mi*(WARP_N/32)][min_tile_n + min_tile_m*2].f32[vec_idx] * scale - max_scaled);
tensor[ni + mi*(WARP_N/32)][min_tile_n + min_tile_m*2].f32[vec_idx] =__llvm_exp2_f32(tensor[ni + mi*(WARP_N/32)][min_tile_n + min_tile_m*2].f32[vec_idx] * scale - max_scaled);
}
#endif
}
}
}
}
}
// Apply the exp to all the elements.
template <bool Scale_max=true, int WARP_M, int BLOCK_N, typename DataType0, typename DataType1>
inline __device__ void scale_apply_exp2_bwd_seq_q_major(DataType0 tensor[(BLOCK_N/32)*(WARP_M/32)][4], const DataType1 max[WARP_M/16], const float scale) {
// const float max_scaled = max[0] * float(M_LOG2E);
#pragma unroll
for (int ni = 0; ni < (BLOCK_N/32); ++ni) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const float max_scaled = (max[mi*2 + min_tile_m] * (Scale_max ? scale : float(M_LOG2E)));
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] =
__llvm_exp2_f32(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] * scale - max_scaled);
// tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] =
// exp2f(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] * scale - max_scaled);
// tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] =
// __llvm_exp2_f32(tensor[mi + ni*(WARP_M/32)][min_tile_n*2 + min_tile_m].f32[vec_idx] * scale - max_scaled + 64) * __llvm_exp2_f32(-64);
}
}
}
}
}
}
#if 0
template<bool Is_first, bool Check_inf=false, typename DataType0, typename DataType1,int K/*head_dim*/, int kBlockK, int WARP_M, int WARP_N>
inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N/32)*(WARP_M/32)][4], DataType1 *scores_max, DataType1 *scores_sum,
DataType0 acc_o[(K/kBlockK) * ((WARP_M/32)*(kBlockK/32))][4], float softmax_scale_log2) {
if (Is_first) {
reduce_max</*zero_init=*/true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max);
scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max, softmax_scale_log2);
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_sum);
} else {
float scores_max_cur[WARP_M/16]; //calculate max of each row
reduce_max</*zero_init=*/false, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max, scores_max_cur); // scores_max is prev scores max
for (int mi = 0; mi < (WARP_M/32); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
float scores_max_cur_reg = !Check_inf
? scores_max_cur[mi*2 + min_tile_m]
: (scores_max_cur[mi*2 + min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi*2 + min_tile_m]);
float scores_scale = __llvm_exp2_f32((scores_max[mi*2 + min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
scores_sum[mi*2 + min_tile_m] *= scores_scale;
#pragma unroll
for(int pv_n_loop=0; pv_n_loop<(K/kBlockK); pv_n_loop++) {
#pragma unroll
for (int ni = 0; ni < (kBlockK/32); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
for(int vec_idx=0; vec_idx<4; vec_idx++) {
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m][vec_idx] = acc_o[pv_n_loop * ((WARP_M/32)*(kBlockK/32)) + (mi + ni*(WARP_M/32))][min_tile_n*2 + min_tile_m][vec_idx] * scores_scale;
}
}
}
}
}
}
scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_max_cur, softmax_scale_log2);
float scores_sum_cur[WARP_M/16]={0.0f};
reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N>(scores, scores_sum_cur);
#pragma unroll
for (int mi = 0; mi < (WARP_M/16); ++mi) { scores_sum[mi] += scores_sum_cur[mi]; }
}
};
#endif
#pragma once
#include "hip/hip_fp16.h"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include "numeric_types.h"
#include "intrinsic.h"
#if defined(__gfx936__) || defined(__gfx938__)
#define parallel_degree 3
#else
#define parallel_degree 2
#endif
template<typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
if (result) {
throw std::runtime_error(std::string("[GPU][ERROR] HIP runtime error: ") + hipGetErrorString(result) + " " + file + ":" + std::to_string(line) + " \n");
}
}
#define check_hip_error(val) check((val), #val, __FILE__, __LINE__)
namespace flash {
inline __device__ constexpr int ceil_div(int const& a, int const& b) {
return (a + b - 1) / b;
}
template<class T>
__device__ vec4_fp32 mmac(const vec4_Element<T> &v1, const vec4_Element<T> &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
template<>
__device__ vec4_fp32 mmac<half_t>(const vec4_fp16 &v1, const vec4_fp16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
template<>
__device__ vec4_fp32 mmac<bhalf_t>(const vec4_bf16 &v1, const vec4_bf16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_bf16(v1, v2, v3);
#endif
}
template<typename T>
__forceinline__ __device__ T __shfl_xor_tmp(T x, const int lane_mask) {
int lane_id = threadIdx.x & 63;
int index = (lane_id ^ lane_mask) << 2;
int res = __builtin_amdgcn_ds_bpermute(index, *(int*)&x); // attention, __builtin only support int
return *(T*)&res;
}
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) {
T res = (x + y);
return res;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 64);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_tmp(x, 32));
return op(x, __shfl_xor_tmp(x, 16));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<32> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
//x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
x = op(x, __shfl_xor_tmp(x, 16));
return x;
}
};
template<typename T, int WARP_M>
void copy(T *src, T *dst) {
for(int i=0; i<(WARP_M/16); i++) {
dst[i] = src[i];
}
}
//TODO:后续优化得用上V_CVT_PKRTZ_FP16_FP32
//QK(seq_q, seq_k), two float in seq_k dim convert to two half, and packed into a U32
template <int BLOCK_M, int WARP_N, typename ElementType>
inline __device__ void convert_pk_type(union_vec2_f16x2<ElementType> p_reg[(BLOCK_M/32)*(WARP_N/32)][4], union_vec4_fp32 s_reg[(BLOCK_M/32)*(WARP_N/32)][4]) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(BLOCK_M/32); m_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
// for(int vec_idx=0; vec_idx<4; vec_idx++) {
// p_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f16[vec_idx] = DownCast<float,ElementType,true>(s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx]);
// }
for(int vec_idx=0; vec_idx<2; vec_idx++) {
p_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f16x2[vec_idx][0] = DownCast<float,ElementType,true>(s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx*2]);
p_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f16x2[vec_idx][1] = DownCast<float,ElementType,true>(s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx*2+1]);
}
}
}
}
}
}
//TODO:后续优化得用上V_CVT_PKRTZ_FP16_FP32
//QK(seq_q, seq_k), two float in seq_k dim convert to two half, and packed into a U32
template <int BLOCK_M, int WARP_N, typename ElementType>
inline __device__ void convert_pk_type_gfx938(union_vec4_f16x2<ElementType> p_reg[(BLOCK_M/32)*(WARP_N/32)*2], union_vec4_fp32 s_reg[(BLOCK_M/32)*(WARP_N/32)][4]) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(BLOCK_M/32); m_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
for(int vec_idx=0; vec_idx<4; vec_idx++) {
p_reg[(n_idx + m_idx*(WARP_N/32)) * 2 + min_tile_n].f16[min_tile_m * 4 + vec_idx] = DownCast<float,ElementType,false>(s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx]);
// p_reg[(n_idx + m_idx*(WARP_N/32)) * 2 + min_tile_n].f16[min_tile_m * 4 + vec_idx] = s_reg[n_idx + m_idx*(WARP_N/32)][min_tile_n*2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
}
template<const int kHeadDim, typename T>
__device__ __forceinline__ vec4_uint tcp_cache_swizzle_func(T* ptr) {
vec4_uint res;
*(uint64_t*)&res = reinterpret_cast<uint64_t>(ptr);
if constexpr (kHeadDim == 196) {
res[1] += 0x41800000; // 62 bit: cache swizzle; 48~61: Stride
} else if constexpr (kHeadDim == 128) {
res[1] += 0x41000000; // stride 256 Bytes and change tagram
} else if constexpr (kHeadDim == 64) {
res[1] += 0x40800000; // stride 128 Bytes and change tagram
}
res[2] = 0x80000000;
res[3] = 0x00020000;
return res;
}
template<typename T>
__device__ __forceinline__ vec4_uint prepare_for_matrix_load_gfx938(T* ptr, int row_stride) {
vec4_uint srsrc;
*(uint64_t*)&srsrc = reinterpret_cast<uint64_t>(ptr);
srsrc[2] = row_stride;
srsrc[3] = 0;
return srsrc;
}
template<class T, class AccumType>
inline __device__ vec4_fp32 mmac(const vec4_Element<T> &v1, const vec4_Element<T> &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
template<>
inline __device__ vec4_fp32 mmac<half_t, float>(const vec4_fp16 &v1, const vec4_fp16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_f16(v1, v2, v3);
#endif
}
template<>
inline __device__ vec4_fp32 mmac<bhalf_t, float>(const vec4_bf16 &v1, const vec4_bf16 &v2, const vec4_fp32 &v3)
{
#if defined(__gfx916__) || defined(__gfx926__)
return {0, 0, 0, 0};
#else
return __builtin_hcu_mmac_f32_16x16x16_bf16(v1, v2, v3);
#endif
}
//封装buffer_load
template<int Is_M_equal,int WARP_NUM,int N_row_len,int M,int N,typename Element>
__forceinline__ __device__ void buffer_load_lds_tile(vec4_uint global_ptr, Element* lds_ptr, int global_offset, int lds_stage_offset, int max_M_len, int warp_id, int lane_id) {
int bytes_per_Element = 2;
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {
bytes_per_Element = 1;
}
int Elment_per_dword = 4/bytes_per_Element;
//M维度index变换,(0, 1, 2, 3) --> (0, 2, 1, 3)
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);
int lane_N_idx = lane_id & 15;
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);
// for(int warp_loop=warp_id; warp_loop<lds_load_num; warp_loop+=WARP_NUM) {
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {
int warp_buffer_load_lds_offset = lds_stage_offset + warp_loop * (4*32);
int gsOffset = global_offset/Elment_per_dword;
int gvOffset;
if constexpr (Is_M_equal){
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;
} else {
gvOffset = (min(warp_loop * 4 + lane_M_idx, max_M_len - 1) * N_row_len)/Elment_per_dword + lane_N_idx;
}
int lds_offset = warp_buffer_load_lds_offset/Elment_per_dword;
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);
}
}
//封装buffer_load
template<int Is_M_equal,int WARP_NUM,int N_row_len,int M,int N,typename Element>
__forceinline__ __device__ void buffer_load_lds_tile_pad(vec4_uint global_ptr, Element* lds_ptr, int global_offset, int lds_stage_offset, int max_M_len, int warp_id, int lane_id) {
int bytes_per_Element = 2;
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {
bytes_per_Element = 1;
}
int Elment_per_dword = 4/bytes_per_Element;
//M维度index变换,(0, 1, 2, 3) --> (0, 2, 1, 3)
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);
int lane_N_idx = lane_id & 15;
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {
int padding = (warp_loop & 7)*2; // padding size in shared memory per buffer load, to avoid bank conflict
int warp_buffer_load_lds_offset = lds_stage_offset + ((warp_loop >> 3)*(32*34) + ( warp_loop & 7)*(4*32));
int gsOffset = global_offset/Elment_per_dword;
int gvOffset;
if constexpr (Is_M_equal){
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;
} else {
gvOffset = (min(warp_loop * 4 + lane_M_idx, max_M_len - 1) * N_row_len)/Elment_per_dword + lane_N_idx;
}
int lds_offset = (warp_buffer_load_lds_offset + padding)/Elment_per_dword;
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);
}
}
//封装ds_read
template<int M, int N, int WARP_NUM, typename Element>
__forceinline__ __device__ void ds_read_tile_pad(vec2_Element<Element>* lds_v2fp16, int lds_stage_offset, union_vec2_f16x2<Element> (*reg)[2], int loop, int warp_id, int lane_id){
#pragma unroll
for(int m_idx = 0; m_idx < M / 32; m_idx ++){
#pragma unroll
for(int n_idx = 0; n_idx < N / 32; n_idx ++){
#pragma unroll
for(int i=0; i<2; i++) {
#pragma unroll
for(int j=0; j<4; j++) {
int lds_offset = lds_stage_offset + n_idx*M*17 + (warp_id*(M/32) + m_idx)*(N*17) + j*2 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read_b32_wait(lds_v2fp16, lds_offset, reg[(loop)*((M*N)/(32*32))*2 + (n_idx*(M/32) + m_idx)*2 + i][j/2].f16x2[j%2]);
}
}
}
}
}
//封装ds_read2
template<int M, int N, int WARP_NUM, typename Element>
__forceinline__ __device__ void ds_read2_tile_pad_no_wait(vec2_Element<Element>* lds_v2fp16, int lds_stage_offset, union_vec2_f16x2<Element> (*reg)[2], int loop, int warp_id, int lane_id){
#pragma unroll
for(int m_idx = 0; m_idx < M / 32; m_idx ++){
#pragma unroll
for(int n_idx = 0; n_idx < N / 32; n_idx ++){
#pragma unroll
for(int i=0; i<2; i++) {
#pragma unroll
for(int j=0; j<2; j++) {
int lds_offset = lds_stage_offset + n_idx*M*17 + (warp_id*(M/32) + m_idx)*(N*17) + j*4 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read2_b32_no_wait(lds_v2fp16, lds_offset, reg[(loop)*((M*N)/(32*32))*2 + (n_idx*(M/32) + m_idx)*2 + i][j].f32, 2);
}
}
}
}
}
//封装buffer_load
#define buffer_load_lds_tile_pad(Is_M_equal, WARP_NUM, N_row_len, M, N, Element, global_ptr, lds_ptr, global_offset, lds_stage_offset, max_M_len, warp_id, lane_id)\
{\
int bytes_per_Element = 2;\
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {\
bytes_per_Element = 1;\
}\
int Elment_per_dword = 4/bytes_per_Element;\
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);\
int lane_N_idx = lane_id & 15;\
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);\
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {\
int padding = (warp_loop & 7);\
int gsOffset = global_offset/Elment_per_dword;\
int gvOffset;\
if constexpr (Is_M_equal){\
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;\
} else {\
gvOffset = (min(warp_loop * 4 + lane_M_idx, max_M_len - 1) * N_row_len)/Elment_per_dword + lane_N_idx;\
}\
int lds_offset = lds_stage_offset/Elment_per_dword + padding + warp_loop * 64;\
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);\
}\
}
//封装buffer_load
#define buffer_load_lds_tile_pad_1(Is_M_equal, WARP_NUM, N_row_len, M, N, Element, global_ptr, lds_ptr, global_offset, lds_stage_offset, max_M_len, warp_id, lane_id)\
{\
int bytes_per_Element = 2;\
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {\
bytes_per_Element = 1;\
}\
int Elment_per_dword = 4/bytes_per_Element;\
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);\
int lane_N_idx = lane_id & 15;\
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);\
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {\
int padding = (warp_loop & 7);\
int gsOffset = global_offset/Elment_per_dword;\
int gvOffset;\
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;\
int lds_offset = lds_stage_offset/Elment_per_dword + padding + warp_loop * 64;\
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);\
}\
}
//封装buffer_load
#define buffer_load_lds_tile(Is_M_equal, WARP_NUM, N_row_len, M, N, Element, global_ptr, lds_ptr, global_offset, lds_stage_offset, max_M_len, warp_id, lane_id)\
{\
int bytes_per_Element = 2;\
if constexpr (std::is_same<Element, int8_t>::value || std::is_same<Element, Float8_e4m3_t>::value) {\
bytes_per_Element = 1;\
}\
int Elment_per_dword = 4/bytes_per_Element;\
int lane_M_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1);\
int lane_N_idx = lane_id & 15;\
const int lds_load_num = (M*N*bytes_per_Element) / (4*64);\
for(int load = 0,warp_loop = warp_id; load < lds_load_num/WARP_NUM; warp_loop += WARP_NUM, ++load) {\
int gsOffset = global_offset/Elment_per_dword;\
int gvOffset;\
if constexpr (Is_M_equal){\
gvOffset = (warp_loop * 4 + lane_M_idx) * N_row_len/Elment_per_dword + lane_N_idx;\
} else {\
gvOffset = (min(warp_loop * 4 + lane_M_idx, max_M_len - 1) * N_row_len)/Elment_per_dword + lane_N_idx;\
}\
int lds_offset = lds_stage_offset/Elment_per_dword + warp_loop * 64;\
builtin_buffer_load_dword_lds(lds_ptr, global_ptr, lds_offset, gsOffset, gvOffset);\
}\
}
#define ds_read_tile_pad(M, N, WARP_NUM, Element, lds_v2fp16, lds_stage_offset, reg, loop, warp_id, lane_id)\
{\
for(int m_idx = 0; m_idx < M / 32; m_idx ++){\
for(int n_idx = 0; n_idx < N / 32; n_idx ++){\
for(int i=0; i<2; i++) {\
for(int j=0; j<4; j++) {\
int lds_offset = lds_stage_offset + n_idx*M*17 + (warp_id*(M/32) + m_idx)*(N*17) + j*2 + i*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);\
inline_ds_read_b32_wait(lds_v2fp16, lds_offset, reg[(loop)*((M*N)/(32*32))*2 + (n_idx*(M/32) + m_idx)*2 + i][j/2].f16x2[j%2]);\
}\
}\
}\
}\
}
#define ds_read2_tile_pad_no_wait(M,N,WARP_NUM,Element,lds_v2fp16,precompute_offset,reg,loop,warp_id,lane_id)\
{\
for(int m_idx = 0; m_idx < M / 32; m_idx ++){\
for(int n_idx = 0; n_idx < N / 32; n_idx ++){\
for(int i=0; i<2; i++) {\
for(int j=0; j<2; j++) {\
inline_ds_read2_b32_no_wait(lds_v2fp16, precompute_B_lds_offset[i*2+j], reg[(loop)*((M*N)/(32*32))*2 + (n_idx*(M/32) + m_idx)*2 + i][j].f32, 2); \
}\
}\
}\
}\
}
#define ds_offset_cast(OFFSET) \
static_cast<int>(reinterpret_cast<uintptr_t>(OFFSET) & 0xFFFFFFFF)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment