Commit a1eef562 authored by shenzhe's avatar shenzhe Committed by zhanghj2
Browse files

Add DSA MLS sparse prefill dispatch

parent 4e0bdf6e
#pragma once #pragma once
#include <cstdlib>
#include "common.h" #include "common.h"
#include "params.h" #include "params.h"
#include "gfx93/prefill/sparse/dsa_mls/fwd.h"
#include "gfx93/prefill/sparse/phase1.h" #include "gfx93/prefill/sparse/phase1.h"
...@@ -39,6 +42,12 @@ class Fwd_Sm90_Impl : public FwdImplBase { ...@@ -39,6 +42,12 @@ class Fwd_Sm90_Impl : public FwdImplBase {
protected: protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
if ((std::getenv("FLASH_MLA_FORCE_DSA_MLS_PREFILL") != nullptr && gfx93::fwd::dsa_mls::can_run(params)) ||
gfx93::fwd::dsa_mls::should_run(params)) {
gfx93::fwd::dsa_mls::run(params);
return;
}
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() { DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
gfx93::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params); gfx93::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
......
#pragma once
#include <algorithm>
#include <hip/hip_runtime.h>
#include "legacy/include/flash.h"
#include "legacy/include/kernel_traits.h"
#include "legacy/include/static_switch.h"
#include "legacy/src/flash_fwd_b16_mla.h"
namespace gfx93::fwd::dsa_mls {
template<typename T, int Headdim, int HeaddimV>
void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) {
constexpr int kBlockM = 64;
constexpr int kBlockN = 64;
constexpr int WARP_M = 16;
dim3 dimBlock;
dimBlock.x = std::min((kBlockM / WARP_M) * 64, 1024);
dimBlock.y = 1;
dimBlock.z = 1;
dim3 dimGrid;
dimGrid.x = (params.seqlen_q + kBlockM - 1) / kBlockM;
dimGrid.y = 1;
dimGrid.z = params.b;
using Kernel_traits = Flash_fwd_kernel_traits<
Headdim, HeaddimV, kBlockM, kBlockN, 32, WARP_M, 64, 2,
false, false, T, T>;
constexpr bool Is_dropout = false;
constexpr bool IsEvenMNConst = false;
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.topk == 2048) {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
});
});
}
} // namespace gfx93::fwd::dsa_mls
#include "fwd.h"
#include <cstring>
#include <stdexcept>
#include <string>
#include "dispatch.h"
namespace gfx93::fwd::dsa_mls {
bool can_run(const SparseAttnFwdParams& params) {
if (params.d_v != 512) return false;
if (params.d_qk != 512 && params.d_qk != 576) return false;
if (params.h_kv != 1) return false;
if (params.h_q != 64 && params.h_q != 128) return false;
if (!(params.topk <= 1024 || params.topk == 2048)) return false;
if (params.topk == 2048 && (params.attn_sink != nullptr || params.topk_length != nullptr)) return false;
return true;
}
bool should_run(const SparseAttnFwdParams& params) {
if (!can_run(params)) return false;
if (params.d_qk == 512 &&
((params.h_q == 64 && params.topk == 512) ||
(params.h_q == 128 && params.topk == 1024))) {
return true;
}
if (params.d_qk == 576 && params.h_q == 64 && params.topk == 2048 && params.s_kv >= 32768) {
return true;
}
return false;
}
static Flash_fwd_mla_params_dsa make_legacy_params(const SparseAttnFwdParams& src) {
if (!can_run(src)) {
throw std::runtime_error(
"DSA MLS sparse prefill only supports d_qk=512/576, d_v=512, "
"h_kv=1, h_q=64/128, topk<=1024 or topk=2048 without attn_sink/topk_length");
}
Flash_fwd_mla_params_dsa dst;
std::memset(&dst, 0, sizeof(dst));
dst.layout = 1;
dst.b = 1;
dst.h = 1;
dst.h_k = src.h_kv;
dst.h_h_k_ratio = dst.h / dst.h_k;
dst.mtp = 1;
dst.ngroups = src.h_q / src.h_kv;
dst.topk = src.topk;
dst.d = src.d_qk;
dst.d_v = src.d_v;
dst.scale_softmax = src.sm_scale;
dst.scale_softmax_log2 = src.sm_scale_div_log2;
dst.cu_seqlens_q = nullptr;
dst.cu_seqlens_k = nullptr;
dst.cu_seqlens_k_new = nullptr;
dst.topk_length = src.topk_length;
dst.attn_sink = src.attn_sink;
dst.q_ptr = src.q;
dst.k_ptr = src.kv;
dst.v_ptr = src.kv;
dst.o_ptr = src.out;
dst.sparse_indices = src.indices;
dst.softmax_lse_ptr = src.lse;
dst.scores_max_ptr = src.max_logits;
dst.scores_sum_ptr = nullptr;
dst.block_table = nullptr;
dst.block_table_batch_stride = 0;
dst.page_block_size = 0;
dst.is_causal = false;
dst.q_batch_stride = 0;
dst.q_token_stride = src.stride_q_s_q;
dst.q_head_stride = src.stride_q_h_q;
dst.q_row_stride = dst.q_head_stride;
dst.k_batch_stride = 0;
dst.k_row_stride = src.stride_kv_s_kv;
dst.k_head_stride = src.stride_kv_h_kv;
dst.v_batch_stride = 0;
dst.v_row_stride = src.stride_kv_s_kv;
dst.v_head_stride = src.stride_kv_h_kv;
dst.o_batch_stride = 0;
dst.o_row_stride = src.h_q * src.d_v;
dst.o_head_stride = src.d_v;
dst.sparse_indices_batch_stride = 0;
dst.sparse_indices_row_stride = src.stride_indices_s_q;
dst.sparse_indices_head_stride = src.stride_indices_h_kv;
dst.sparse_indices_topk_stride = 1;
dst.seqlen_q = src.s_q * dst.ngroups;
dst.seqlen_k = src.s_kv;
dst.max_seqlen = src.s_q;
dst.is_bf16 = true;
dst.is_e4m3 = false;
dst.is_int8 = false;
dst.cu_count = src.num_sm;
dst.seqlenq_ngroups_swapped = true;
dst.is_seqlens_k_cumulative = false;
dst.splitkv_use_fp32_as_accum = false;
dst.num_splits = 0;
dst.partition_size = src.topk;
return dst;
}
void run(const SparseAttnFwdParams& params) {
Flash_fwd_mla_params_dsa legacy_params = make_legacy_params(params);
hipStream_t stream = reinterpret_cast<hipStream_t>(params.stream);
if (params.d_qk == 512) {
run_dsa_prefill_nopage_64_dispatch<BFloat16, 512, 512>(legacy_params, stream);
} else if (params.d_qk == 576) {
run_dsa_prefill_nopage_64_dispatch<BFloat16, 576, 512>(legacy_params, stream);
} else {
throw std::runtime_error("Unsupported d_qk value in DSA MLS sparse prefill");
}
}
} // namespace gfx93::fwd::dsa_mls
#pragma once
#include "../../../../params.h"
namespace gfx93::fwd::dsa_mls {
bool can_run(const SparseAttnFwdParams& params);
bool should_run(const SparseAttnFwdParams& params);
void run(const SparseAttnFwdParams& params);
} // namespace gfx93::fwd::dsa_mls
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true, bool Is_Kvcache=false, bool USE_BSHD_LAYOUT = false>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q((!Varlen || params.cu_seqlens_q == nullptr) ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k((!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative) ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr || Is_Kvcache ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache/* + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)*/)
, nheads(params.h)
, nheads_k(params.h_k)
, leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
{
}
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
}
inline __device__ int q_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_q) * row_stride : uint32_t(sum_s_q) * row_stride * nheads);
}
inline __device__ int k_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads_k);
}
inline __device__ int k_offset1_write(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads);
}
inline __device__ int q_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_q == -1) ? bidh * head_stride : uint32_t(actual_seqlen_q) * head_stride * bidh;
}
inline __device__ int k_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_k == -1) ? bidh * head_stride : uint32_t(actual_seqlen_k) * head_stride *bidh;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int leftpad_k;
const int seqlen_k_cache;
int actual_seqlen_k;
const int nheads;
const int nheads_k;
};
// Simplified blockinfo for tranditional varlen fwd inference
template<bool USE_BSHD_LAYOUT=false>
struct SimplifyBlockInfo {
template<typename Params>
__device__ SimplifyBlockInfo(const Params &params, const int bidb)
: sum_s_q(params.cu_seqlens_q[bidb])
, sum_s_k(params.cu_seqlens_k[bidb])
, actual_seqlen_q(params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache((params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache/* + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)*/)
, nheads(params.h)
, nheads_k(params.h_k)
// , leftpad_k(0)
{
}
inline __device__ int q_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_q) * row_stride : uint32_t(sum_s_q) * row_stride * nheads);
}
inline __device__ int k_offset1(const int batch_stride, const int row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : (USE_BSHD_LAYOUT ? uint32_t(sum_s_k) * row_stride : uint32_t(sum_s_k) * row_stride * nheads_k);
}
inline __device__ int q_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_q == -1) ? bidh * head_stride : uint32_t(actual_seqlen_q) * head_stride * bidh;
}
inline __device__ int k_offset2(const int head_stride, const int bidh) const {
return (USE_BSHD_LAYOUT || sum_s_k == -1) ? bidh * head_stride : uint32_t(actual_seqlen_k) * head_stride *bidh;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
// const int leftpad_k;
const int seqlen_k_cache;
int actual_seqlen_k;
const int nheads;
const int nheads_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SafeDecodeBlockInfo {
__device__ SafeDecodeBlockInfo() = default;
template<typename Params, bool Is_Q_varlen, bool Is_K_Cumulative>
__device__ void set_params(const Params &params, const int bidb) {
// process Q
if constexpr (Is_Q_varlen) { // Is_Q_varlen also means Is_Q_Cumulative = true
this->sum_s_q = params.cu_seqlens_q[bidb];
this->actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - this->sum_s_q;
} else {
this->actual_seqlen_q = params.seqlen_q;
}
// process KV
if constexpr (Is_K_Cumulative) {
this->sum_s_k = params.cu_seqlens_k[bidb];
this->actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
} else {
this->actual_seqlen_k = params.cu_seqlens_k[bidb];
}
}
int sum_s_q;
int sum_s_k;
int actual_seqlen_q;
int actual_seqlen_k;
};
} // namespace flash
#pragma once
#include <block_info.h>
#include "utils.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccumType, int kBlockM_, int kBlockN_, int WARP_M_, int WARP_N_, int kHeadDim_, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccumType* dsoftmax_sum = static_cast<ElementAccumType*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id =0;
__shared__ Element do_lds[STAGES*(kBlockM_/32) * (kBlockN_/32)*(32*34)];
__shared__ Element o_lds[STAGES*(kBlockM_/32) * (kBlockN_/32)*(32*34)];
float dP_sum_cur[(kBlockM_/16)] = {0.0f};
int stage_id = 0;
constexpr int kBlockM = kBlockM_;
constexpr int kBlockN = kBlockN_;
constexpr int kHeadDim = kHeadDim_;
const int WARP_NUM = (kBlockM_)/(WARP_M_);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
// Element *gdO = reinterpret_cast<Element *>(do_ptr) + row_offset_do;
auto gdO = tcp_cache_swizzle_func<kHeadDim_, Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do);
// Element *gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
auto gO = tcp_cache_swizzle_func<kHeadDim_, Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o);
ElementAccumType *dP_sum = reinterpret_cast<ElementAccumType *>(dsoftmax_sum) + row_offset_dpsum;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_id)
: "v"(warp_id_vec)
:);
vec2_Element<Element> do_reg[(kHeadDim_/kBlockN_)*((WARP_M_*kBlockN_)/(32*32))*2][4]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
vec2_Element<Element> o_reg[(kHeadDim_/kBlockN_)*((WARP_M_*kBlockN_)/(32*32))*2][4]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
// int A_lane_m_idx = (lane_id >> 4);
int do_lane_m_idx = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
int do_lane_head_dim_idx = (lane_id & 15);
//global->lds, left matrix
// printf("kBlockN_==%d, kHeadDim_=%d, WARP_M_=%d\n",kBlockN_, kHeadDim_, WARP_M_);
for(int k_loop=0; k_loop<kHeadDim_/kBlockN_; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN_;
// do_ptr buffer load mini size is 4*32, (kBlockM_ * kBlockN_) mini size is (32*32)
const int do_lds_load_num = (kBlockM_ * kBlockN_) / (4*32);
int do_lds_stage_offset = stage_id * (kBlockM_/32) * (kBlockN_/32)*(32*34);
for(int warp_loop=warp_id; warp_loop < do_lds_load_num; warp_loop+=WARP_NUM) {
int padding = (warp_loop & 7)*2; // padding size in shared memory per buffer load, to avoid bank conflict
int do_warp_buffer_load_m_id = (warp_loop & (kBlockM_/4 - 1)); //这样子对L1和utlc1有啥影响呢?
int do_warp_buffer_load_k_id = (warp_loop / (kBlockM_/4));
int do_warp_buffer_load_lds_offset = do_lds_stage_offset + (do_warp_buffer_load_k_id * kBlockM_ * 34) + ((do_warp_buffer_load_m_id >> 3)*(32*34) + (do_warp_buffer_load_m_id & 7)*(4*32)) ;
int do_warp_buffer_load_global_offset = (do_warp_buffer_load_k_id * 32);
int gsOffset = (do_block_buffer_load_global_offset + do_warp_buffer_load_global_offset)/2 ;
// int gvOffset = (do_lane_m_idx * kHeadDim_)/2 + do_lane_head_dim_idx;
int lds_offset = (do_warp_buffer_load_lds_offset + padding)/2;
{
int gvOffset;
if constexpr (!Is_even_MN) {
gvOffset = (min((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)),binfo.actual_seqlen_q - m_block * kBlockM - 1) * seqlen_do_stride)/2 + do_lane_head_dim_idx;
} else {
gvOffset = ((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)) * seqlen_do_stride)/2 + do_lane_head_dim_idx;
}
builtin_buffer_load_dword_lds(do_lds, gdO, lds_offset, gsOffset, gvOffset);
}
{
int gvOffset;
if constexpr (!Is_even_MN) {
gvOffset = (min((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)),binfo.actual_seqlen_q - m_block * kBlockM - 1) * seqlen_o_stride)/2 + do_lane_head_dim_idx;
} else {
gvOffset = ((do_lane_m_idx + (do_warp_buffer_load_m_id * 4)) * seqlen_o_stride)/2 + do_lane_head_dim_idx;
}
builtin_buffer_load_dword_lds(o_lds, gO, lds_offset, gsOffset, gvOffset);
}
}
vmcnt_wait(0);
// By right we need to scale dP up by 1/params.p_dropout, but instead we don't and only scale the final
// results (dQ and dK) by 1/params.p_dropout. So we need to keep dP_sum scaled down by params.p_dropout here,
// so that (dP - dP_sum) is on the same scale.
{
//lds -> vgpr use ds_read_m; left matrix
int do_warp_m_id = (warp_id & ((kBlockM_/WARP_M_) - 1));
int do_lds_stage_offset = stage_id * (kBlockM_/32) * (kBlockN_/32)*(32*17);
vec2_Element<Element> *do_lds_v2fp16 = (vec2_Element<Element> *)(do_lds);
vec2_Element<Element> *o_lds_v2fp16 = (vec2_Element<Element> *)(o_lds);
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockN_/32); head_dim_idx++) { //32 half in col direction
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
//a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { //sequence direction
#pragma unroll
for(int vec_id=0; vec_id<4; vec_id++) { //head dim direction
int lds_offset = do_lds_stage_offset + head_dim_idx*kBlockM_*17 + (warp_id*(WARP_M_/32) + m_idx)*(32*17) + vec_id*2 + min_tile_m*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read_b32_wait(do_lds_v2fp16, lds_offset, do_reg[/*(k_loop)*((WARP_M_*kBlockN_)/(32*32))*2 +*/ (head_dim_idx*(WARP_M_/32) + m_idx)*2 + min_tile_m][vec_id]);
}
}
}
}
#pragma unroll
for(int head_dim_idx=0; head_dim_idx<(kBlockN_/32); head_dim_idx++) { //32 half in col direction
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
//a warp load min size is (row, col) = (32,16) float
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { //sequence direction
#pragma unroll
for(int vec_id=0; vec_id<4; vec_id++) { //head dim direction
int lds_offset = do_lds_stage_offset + head_dim_idx*kBlockM_*17 + (warp_id*(WARP_M_/32) + m_idx)*(32*17) + vec_id*2 + min_tile_m*32 + (lane_id & 1)*16 + ((lane_id & 15)>>1)*64 + /*padding*/ ((lane_id & 15)>>1) + ((lane_id/16) &1)*8 + (lane_id/32);
inline_ds_read_b32_wait(o_lds_v2fp16, lds_offset, o_reg[/*(k_loop)*((WARP_M_*kBlockN_)/(32*32))*2 +*/ (head_dim_idx*(WARP_M_/32) + m_idx)*2 + min_tile_m][vec_id]);
}
}
}
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + mi*32 + min_tile_m + (threadIdx.x & 15)*2) < binfo.actual_seqlen_q) {
dP_sum_cur[mi*2 + min_tile_m] += UpCast<Element,float,true>(do_reg[(head_dim_idx*(WARP_M_/32) + mi)*2 + min_tile_m][vec_id][min_tile_n]) * UpCast<Element,float,true>(o_reg[(head_dim_idx*(WARP_M_/32) + mi)*2 + min_tile_m][vec_id][min_tile_n]);
}
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m + (threadIdx.x & 15)*2] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
\ No newline at end of file
#pragma once
#include <block_info.h>
#include "utils.h"
#include "prefetch.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccum, int kBlockM, int kBlockN, int WARP_M, int WARP_N, int K, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o_gfx938(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = 0;
__shared__ Element dO_lds[kBlockM * kBlockN];
__shared__ Element O_lds[kBlockM * kBlockN];
float dP_sum_cur[(kBlockM/16)] = {0.0f};
const int WARP_NUM = (kBlockM)/(WARP_M);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do, seqlen_do_stride);
auto gO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o, seqlen_o_stride);
ElementAccum *dP_sum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> dO_reg[((WARP_M*kBlockN)/(32*32))*2];
union_vec4_f16x2<Element> O_reg[((WARP_M*kBlockN)/(32*32))*2];
for(int k_loop=0; k_loop<K/kBlockN; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN;
//read 32 * 128
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gdO, do_block_buffer_load_global_offset, dO_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gO, do_block_buffer_load_global_offset, O_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for(int i = 0; i < kBlockN / 32; ++i) {
DS_READ_MATRIX_32X32_B16(ds_offset_cast(dO_lds + i * 32 * 32), dO_reg[i * 2 + 0].f16, dO_reg[i * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(ds_offset_cast(O_lds + i * 32 * 32), O_reg[i * 2 + 0].f16, O_reg[i * 2 + 1].f16, true);
// if constexpr (std::is_same_v<Element, half_t>) {
// dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
// dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
// O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 0, 2, 1, 0);
// O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
// } else {
// dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
// dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
// O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 0, 2, 1, 0);
// O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
// }
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + min_tile_m*16 + (threadIdx.x & 15)) < binfo.actual_seqlen_q) {
dP_sum_cur[min_tile_m] += UpCast<Element,float,false>(dO_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]) * UpCast<Element,float,false>(O_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]);
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m * 16 + (threadIdx.x & 15)] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
#pragma once
#include <block_info.h>
#include "utils.h"
#include "prefetch.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccum, int kBlockM, int kBlockN, int WARP_M, int WARP_N, int K, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o_gfx946(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = 0;
__shared__ Element dO_lds[kBlockM * kBlockN];
__shared__ Element O_lds[kBlockM * kBlockN];
float dP_sum_cur[(kBlockM/16)] = {0.0f};
const int WARP_NUM = (kBlockM)/(WARP_M);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do, seqlen_do_stride);
auto gO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o, seqlen_o_stride);
ElementAccum *dP_sum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_id)
: "v"(warp_id_vec)
:);
union_vec4_f16x2<Element> dO_reg[((WARP_M*kBlockN)/(32*32))*2];
union_vec4_f16x2<Element> O_reg[((WARP_M*kBlockN)/(32*32))*2];
for(int k_loop=0; k_loop<K/kBlockN; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN;
//read 32 * 128
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gdO, do_block_buffer_load_global_offset, dO_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gO, do_block_buffer_load_global_offset, O_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for(int i = 0; i < kBlockN / 32; ++i) {
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(dO_lds + i * 32 * 32), dO_reg[i * 2 + 0].f16, dO_reg[i * 2 + 1].f16, true);
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(O_lds + i * 32 * 32), O_reg[i * 2 + 0].f16, O_reg[i * 2 + 1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 0, 2, 1, 0);
O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
} else {
dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 0, 2, 1, 0);
O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + min_tile_m*16 + (threadIdx.x & 15)) < binfo.actual_seqlen_q) {
dP_sum_cur[min_tile_m] += UpCast<Element,float,false>(dO_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]) * UpCast<Element,float,false>(O_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]);
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m * 16 + (threadIdx.x & 15)] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
\ No newline at end of file
#include <iostream>
#include <memory>
#include <vector>
#include <random>
#include <fstream>
#include <stdlib.h>
#include <dirent.h>
#include <unistd.h>
#include <sys/stat.h>
#include "assert.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "flash.h"
#include "utils.h"
#include "wait.h"
#include "../numeric_types.h"
#include "philox.cuh"
#include "softmax_tiling.h"
#include "gpu_gemm_nn.h"
#include "gpu_gemm_tt.h"
#include "intrinsic.h"
#include "intrinsic_mls_ds.h"
#include "static_switch.h"
#include "dot_do_o.h"
#include "dot_do_o_gfx938.h"
#include "dot_do_o_gfx946.h"
#include "prefetch.h"
#include "flash_singleton.h"
#include "flash_attention_dv_dk_bwd.h"
#include "flash_attention_dv_dk_bwd_gfx938.h"
#include "flash_attention_dv_dk_bwd_gfx946.h"
#include "flash_attention_dq_bwd.h"
#include "flash_attention_dq_bwd_gfx938.h"
#include "flash_attention_dq_bwd_gfx946.h"
using std::make_shared;
using std::shared_ptr;
template <int kBlockM_, int kBlockN_, int WARP_M_, int WARP_N_, typename Element>
inline __device__ void reshape(Element* smem, vec4_Element<Element> ds_reg_fp16[(WARP_N_/32)*(WARP_M_/32)][4], int warp_id) {
int lane_id = threadIdx.x & 63; //lane id, 0-63
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
int lds_offset = warp_id*(WARP_N_/32)*33*kBlockM_ + n_idx*33*kBlockM_ + m_idx*32*33 + min_tile_m*16*33 + vec_idx*4*33 + (lane_id>>4)*33 + min_tile_n*16 + (lane_id&15);
Element ds_reg_tmp = ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx];
{
smem[lds_offset] = ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx];
}
}
}
}
}
}
#pragma unroll
for(int m_idx=0; m_idx<(kBlockM_/32); m_idx++) {
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {
#pragma unroll
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_n = 0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
int lds_offset = warp_id*33*kBlockM_ + m_idx*32*33 + min_tile_m*16 + vec_idx*4 + (lane_id>>4) + min_tile_n*16*33 + (lane_id&15)*33;
ds_reg_fp16[(WARP_N_/32)*m_idx + n_idx][min_tile_m*2 + min_tile_n][vec_idx] = smem[lds_offset];
}
}
}
}
}
}
/*
* q_ptr: Transposed 32x16 matrix
* k_ptr: Non-transposed 32x16 matrix
* qk_ptr: Non-transposed 32x32 matrixseqlen_q
*/
template<class DataType>
int check_param(int seqlen_q, int seqlen_k, int K, int kBlockM_, int kBlockN_, int kBlockK_, int WARP_M_, int WARP_N_, dim3 blockDim, dim3 gridDim, int maxBlockThreads, int STAGES) {
// min warp size is 32x32
if(WARP_M_<32 || WARP_N_<32) {
std::cout<<"Error, WARP_M_<32 or WARP_N_<32!"<<std::endl;
assert(((WARP_M_>=32) && (WARP_N_>=32)));
}
// check block threads number
const int blockThreads = ((kBlockM_*kBlockN_)/(WARP_M_*WARP_N_)*64);
if(blockThreads > maxBlockThreads) {
std::cout<<"Error,Block threads is greater than maxBlockThreads! "<<std::endl;
assert(blockThreads <= maxBlockThreads);
}
//check lds data numbers
int DataTypeSize = sizeof(DataType);
const int q_lds_size = STAGES * kBlockM_ * kBlockK_ * DataTypeSize;
const int k_lds_size = STAGES * kBlockN_ * kBlockK_ * DataTypeSize;
if(((q_lds_size + k_lds_size)/1024) > 64) {
std::cout<<"Error, shared memory size is greater than 64KB"<<std::endl;
assert(((q_lds_size + k_lds_size)/1024) <= 64); //BW lds 64KB
}
}
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16*2*params.seqlen_k + lane_id/16*2 + warp_m_idx * params.seqlen_k + warp_n_idx + vec_idx * 8; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + (kBlockN_/32)*(K/32)*(32*34);//(Is_preload_K || Is_store_K) ? (Element*)&(smem) + (kBlockN_/32)*(K/32)*(32*34) : (Element*)&(smem);
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
// Element * gQ = reinterpret_cast<Element *>(q_ptr) + row_offset_q;
auto gQ = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q);
// Element * gK = reinterpret_cast<Element *>(k_ptr) + row_offset_k;
auto gK = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k);
// Element * gV = reinterpret_cast<Element *>(v_ptr) + row_offset_v;
auto gV = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v);
// Element * gdO = reinterpret_cast<Element *>(do_ptr) + row_offset_dO;
auto gdO = tcp_cache_swizzle_func<K_v, Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
auto gdQ = tcp_cache_swizzle_func<K, Element>(reinterpret_cast<Element *>(dq_ptr));
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec2_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2][2];
union_vec2_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2][2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + ((lane_id & 15)*2) + min_tile_m;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + ((lane_id & 15)*2) + min_tile_m;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr<K, kBlockM_, kBlockK_, WARP_N_, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), seqlen_q_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr<K_v, kBlockM_, kBlockK_, WARP_N_, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), seqlen_do_stride);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_K){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gK, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id, seqlen_k_stride);
}
if constexpr (Is_preload_V){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gV, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id, seqlen_v_stride);
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec2_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2][2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
{
//dp gemm
gemm_tt_kq<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec2_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2][2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gV, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id, seqlen_v_stride);
}
apply_mask_bwd<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec2_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)][4];
convert_pk_type<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg<Is_store_K , false , false, Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0){
// printf("(binfo.actual_seqlen_k - n_block * kBlockN_) = %d\n", (binfo.actual_seqlen_k - n_block * kBlockN_));
// }
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_tmp_lds_wait<Is_even_MN, K, kBlockM_, kBlockN_, kBlockK_, WARP_M_, WARP_N_, Element>(gK, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id, seqlen_k_stride);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
{
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
int dq_block_buffer_store_global_offset = k_loop * kBlockK_;
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
int dq_warp_buffer_store_global_offset = (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx*2) * seqlen_dq_stride;
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
dq_global_addr_offset = dq_block_buffer_store_global_offset + dq_warp_buffer_store_global_offset + k_tile_idx*32;
#pragma unroll 2
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
#pragma unroll 2
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int dq_global_addr = dq_global_addr_offset + (min_tile_m + vec_index*8)*seqlen_dq_stride + min_tile_k + dq_lane_head_dim_idx*2;
if(Is_even_MN || ((m_block * kBlockM_) + (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx*2) + min_tile_m + vec_index*8) < binfo.actual_seqlen_q) {
dq_ptr[dq_global_addr] = DownCast<ElementAccum, Element>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_k + min_tile_m*2].f32[vec_index] * params.scale_softmax_rp_dropout);
}
}
}
}
}
}
}
}
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock_gfx938(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + kBlockN_* K;
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_V){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
if constexpr (Is_preload_K){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec4_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
//dP gemm
gemm_tt_kq_gfx938<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride
);
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq_gfx938<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride
);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)*2];
convert_pk_type_gfx938<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg_gfx938<Is_store_K , false , Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
//mmac
{
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
int dq_block_buffer_store_global_offset = k_loop * kBlockK_;
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
int dq_warp_buffer_store_global_offset = (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx) * seqlen_dq_stride;
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
dq_global_addr_offset = dq_block_buffer_store_global_offset + dq_warp_buffer_store_global_offset + k_tile_idx*32;
#pragma unroll 2
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
#pragma unroll 2
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
int dq_global_addr = dq_global_addr_offset + (min_tile_m*16 + vec_index*4)*seqlen_dq_stride + min_tile_k + dq_lane_head_dim_idx*2;
if(Is_even_MN || ((m_block * kBlockM_) + (warp_id*WARP_M_ + warp_m_idx*32 + dq_lane_seq_idx) + min_tile_m*16 + vec_index*4) < binfo.actual_seqlen_q) {
dq_ptr[dq_global_addr] = DownCast<ElementAccum, Element>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_k + min_tile_m*2].f32[vec_index] * params.scale_softmax_rp_dropout);
}
}
}
}
}
}
}
}
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock_gfx946(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + kBlockN_* K;
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_V){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
if constexpr (Is_preload_K){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec4_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
//dP gemm
gemm_tt_kq_gfx938<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride
);
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq_gfx938<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride
);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
// dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)*2];
convert_pk_type_gfx938<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg_gfx946<Is_store_K , false , Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0){
// printf("(binfo.actual_seqlen_k - n_block * kBlockN_) = %d\n", (binfo.actual_seqlen_k - n_block * kBlockN_));
// }
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
dq_ptr = dq_ptr + binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh);
auto gdQ = tcp_cache_swizzle_func<K_v, Element>(dq_ptr);
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dq_lane_head_dim_idx * seqlen_dq_stride + dq_lane_seq_idx * 4;
int s_offset = (min_tile_m * seqlen_dq_stride * 16 + vec_index % 2 * 2 + vec_index / 2 * 16) + (k_tile_idx*32) + ((warp_id*WARP_M_ + warp_m_idx*32) * seqlen_dq_stride) + (k_loop * kBlockK_ + m_block * kBlockM_ * seqlen_dq_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_m*2 + vec_index / 2].f32[vec_index % 2 * 2] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_m*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1] * params.scale_softmax_rp_dropout);
if (Is_even_MN || min_tile_m*16 + (warp_id*WARP_M_ + warp_m_idx*32) + m_block * kBlockM_ + dq_lane_head_dim_idx < binfo.actual_seqlen_q){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdQ, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
#endif
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment