Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
......@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -48,7 +48,7 @@ static __device__ inline float to_float(scalar_t in){
inline __device__ float uint82float(const uint8_t& input) {
#if (defined(__gfx938__) )
#if (defined(__gfx938__) ||defined(__gfx92a__))
return __builtin_hcu_cvt_f32_fp8(input,false,0,0);
#else
const uint32_t w = (uint32_t)input << 24;
......@@ -137,11 +137,11 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 48; \
if (reusekv==64){ \
constexpr static int REUSE_KV_TIMES = 64; \
return __VA_ARGS__(); \
}else if (reusekv==36){ \
constexpr static int REUSE_KV_TIMES = 36; \
}else if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 48; \
return __VA_ARGS__(); \
}else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \
......@@ -257,18 +257,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template<bool is_half>
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
#if (defined(__gfx938__) )
if constexpr (is_half){reg_c=__builtin_hcu_mmac_f32_16x16x16_f16_lit_lts(reg_a,reg_b,reg_c,false,false);}
else{
reg_c=__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c,false,false);
}
#else
if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);}
else{
reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
if constexpr (is_half){
reg_c=__builtin_hcu_mmac_f32_16x16x16_f16(reg_a,reg_b,reg_c);
}else{
reg_c=__builtin_hcu_mmac_f32_16x16x16_bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
}
#endif
}
......@@ -390,23 +383,28 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
for(int m=0;m<Mloop;m++){
qk_vec[m]={0,0,0,0};
}
half4x2 k_vec[HEAD_SIZE/32];
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i=0;i<HEAD_SIZE/32;i++){
half4x2 k_vec;
if constexpr(is_fp8){
uint8x4x2 k_vec_u8=*reinterpret_cast<const uint8x4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
scalar_t *p1=(scalar_t*)&k_vec;
scalar_t *p1=(scalar_t*)(k_vec+i);
uint8_t *p2=(uint8_t*)&k_vec_u8;
for(int ii=0;ii<8;ii++){
p1[ii]=uint82half<scalar_t,is_e4m3>(p2[ii]);
}
}
else{
k_vec=*reinterpret_cast<const half4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
k_vec[i]=*reinterpret_cast<const half4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
}
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i=0;i<HEAD_SIZE/32;i++){
for(int m=0;m<Mloop;m++){
builtin_amdgcn_mmac<is_half>(k_vec.data[0],q_vec[m][i].data[0],qk_vec[m]);
builtin_amdgcn_mmac<is_half>(k_vec.data[1],q_vec[m][i].data[1],qk_vec[m]);
builtin_amdgcn_mmac<is_half>(k_vec[i].data[0],q_vec[m][i].data[0],qk_vec[m]);
builtin_amdgcn_mmac<is_half>(k_vec[i].data[1],q_vec[m][i].data[1],qk_vec[m]);
}
}
#pragma unroll
......@@ -597,7 +595,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
if(partition_idx<num_partitions-1){
if(partition_idx<num_partitions-1||block_idx < num_seq_blocks-1){
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*vecsize*4+rowid*BLOCK_SIZE;
......@@ -635,13 +633,11 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
v_vec=*reinterpret_cast<const half4_vec*>(v_ptr + offset);
}
//这里的if判断会影响一定的性能,因此只有最后一个patition才判断
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4*vecsize; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
}
}
for(int ii=0;ii<vecsize;ii++){
for(int m=0;m<Mloop;m++){
builtin_amdgcn_mmac<is_half>(v_vec.data[ii],logits_vec[m].data[ii],accs[m][i]);
......@@ -756,8 +752,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
}
static int get_reusekv(int qhead,int kv_head){
if(qhead>kv_head*36) return 48;
if(qhead>kv_head*32) return 36;//glm4.7 mtp 3
if(qhead>kv_head*48) return 64;
if(qhead>kv_head*32) return 48;
if(qhead>kv_head*24) return 32;
if(qhead>kv_head*16) return 24;
if(qhead>kv_head*8) return 16;
......@@ -831,7 +827,7 @@ void paged_attention(
hipMalloc(&tmp_out_ptr, temp_out_size); // 100m
hipMemset(tmp_out_ptr,0,temp_out_size);
}
if(device_name=="gfx938"&&(key_cache.dtype()==torch::kFloat8_e5m2||key_cache.dtype()==torch::kFloat8_e4m3fn)){
if((device_name=="gfx938"|| device_name == "gfx92a")&&(key_cache.dtype()==torch::kFloat8_e5m2||key_cache.dtype()==torch::kFloat8_e4m3fn)){
paged_attention_938(out,query,key_cache,value_cache,block_tables,seq_lens,alibi_slopes,q_scale,k_scale,v_scale,max_seq_len,s_aux_,tmp_out_ptr,PARTITION_SIZE);
return;
}
......@@ -876,16 +872,16 @@ void paged_attention(
grid.x = num_kv_heads;
grid.y = num_seqs;
AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256");
AT_ASSERTM(num_heads<=num_kv_heads*48, "Page Attention qheads*mtp/kvheads must be smaller than 48");
AT_ASSERTM(num_heads<=num_kv_heads*64, "Page Attention qheads*mtp/kvheads must be smaller than 48");
HEADSIZE_SWITCH(headsize,[&]{
Input_Type_SWITCH(query.dtype(),[&]{
Cache_Type_SWITCH(scalar_t,key_cache.dtype(),[&] {
REUSEKV_SWITCH(reusekv,[&] {
BOOL_SWITCH(block_size==64,is_block64,[&]{
constexpr int BLOCK_SIZE = (is_block64?64:128);
// constexpr int BLOCK_SIZE=128;
// constexpr int BLOCK_SIZE = (is_block64?64:128);
constexpr int BLOCK_SIZE=64;
// constexpr int HEAD_SIZE=128;
// using scalar_t=_Float16;
// using scalar_t=uint16_t;
// using cache_t = scalar_t;
constexpr bool is_e4m3=false;
// constexpr static int REUSE_KV_TIMES = 4;
......
......@@ -58,7 +58,7 @@ static __device__ inline float to_float(scalar_t in){
inline __device__ float uint82float(const uint8_t& input) {
#if (defined(__gfx938__) )
#if (defined(__gfx938__)||defined(__gfx92a__) )
return __builtin_hcu_cvt_f32_fp8(input,false,0,0);
#else
const uint32_t w = (uint32_t)input << 24;
......@@ -106,7 +106,7 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
template <bool is_e4m3>
static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) {
int val=0;
#if (defined(__gfx938__) )
#if (defined(__gfx938__) || defined(__gfx92a__))
if constexpr(is_e4m3){
val = __builtin_hcu_cvt_pk_fp8_f32(v1,v2,val,false);
val = __builtin_hcu_cvt_pk_fp8_f32(v3,v4,val,true);
......@@ -122,7 +122,7 @@ static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) {
template <bool is_e4m3>
static __device__ float4_t to_fp32_from_fp8(int val) {
float4_t ret;
#if (defined(__gfx938__) )
#if (defined(__gfx938__) || defined(__gfx92a__))
if constexpr(is_e4m3){
ret[0] = __builtin_hcu_cvt_f32_fp8(val,false,0,0);
ret[1] = __builtin_hcu_cvt_f32_fp8(val,false,0,1);
......@@ -184,11 +184,11 @@ static __device__ float4_t to_fp32_from_fp8(int val) {
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 48; \
if (reusekv==64){ \
constexpr static int REUSE_KV_TIMES = 64; \
return __VA_ARGS__(); \
}else if (reusekv==36){ \
constexpr static int REUSE_KV_TIMES = 36; \
}else if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 48; \
return __VA_ARGS__(); \
}else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \
......@@ -303,13 +303,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template<bool is_e4m3>
inline __device__ void builtin_amdgcn_mmac(const intx2& reg_a, const intx2& reg_b, float4_t& reg_c)
{
#if (defined(__gfx938__) )
if constexpr(is_e4m3){
reg_c=__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(reg_a,reg_b,reg_c,false,false);
reg_c=__builtin_hcu_mmac_f32_16x16x32_fp8_fp8(reg_a,reg_b,reg_c);
}else{
reg_c=__builtin_hcu_mmac_f32_16x16x32_bf8_bf8_lit_lts(reg_a,reg_b,reg_c,false,false);
reg_c=__builtin_hcu_mmac_f32_16x16x32_bf8_bf8(reg_a,reg_b,reg_c);
}
#endif
}
template <typename scalar_t,typename q_type,bool is_e4m3 ,int HEAD_SIZE, int BLOCK_SIZE,
......@@ -332,7 +330,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr,
int max_num_partitions,int PARTITION_SIZE,
const scalar_t* __restrict__ s_aux_ptr,int mtp,bool has_abili) { // ★ Attention Sinks: [num_heads] scalar_t ★
#if (defined(__gfx938__) )
#if (defined(__gfx938__)||defined(__gfx92a__) )
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
constexpr int kv_head_stride=BLOCK_SIZE*HEAD_SIZE;
......@@ -453,10 +451,16 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
for(int m=0;m<Mloop;m++){
qk_vec[m]={0,0,0,0};
}
intx4 k_vec[HEAD_SIZE/64];
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i=0;i<HEAD_SIZE/64;i++){
intx4 k_vec=*reinterpret_cast<const intx4*>(k_ptr+i*64+rowid*HEAD_SIZE+rows*16);
intx2 *k_vec_2 = (intx2*)&k_vec;
k_vec[i]=*reinterpret_cast<const intx4*>(k_ptr+i*64+rowid*HEAD_SIZE+rows*16);
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i=0;i<HEAD_SIZE/64;i++){
intx2 *k_vec_2 = (intx2*)(k_vec+i);
for(int m=0;m<Mloop;m++){
intx2 *q_vec_2 = (intx2*)(&q_vec[m][i]);
builtin_amdgcn_mmac<is_e4m3>(k_vec_2[0],q_vec_2[0],qk_vec[m]);
......@@ -655,7 +659,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
const uint8_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
if(partition_idx<num_partitions-1){
if(partition_idx<num_partitions-1||block_idx < num_seq_blocks-1){
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE;
......@@ -673,13 +677,11 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE;
int_vec v_vec = *reinterpret_cast<const int_vec*>(v_ptr + offset);
//这里的if判断会影响一定的性能,因此只有最后一个patition才判断
if (block_idx == num_seq_blocks - 1) {
uint8_t* v_vec_ptr = reinterpret_cast<uint8_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 16; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
}
}
for(int ii=0;ii<vecsize;ii++){
for(int m=0;m<Mloop;m++){
builtin_amdgcn_mmac<is_e4m3>(v_vec.data[ii],logits_vec[m][ii],accs[m][i]);
......@@ -795,8 +797,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
}
static int get_reusekv(int qhead,int kv_head){
if(qhead>kv_head*36) return 48;
if(qhead>kv_head*32) return 36;//glm4.7 mtp 3
if(qhead>kv_head*48) return 64;
if(qhead>kv_head*32) return 48;
if(qhead>kv_head*24) return 32;
if(qhead>kv_head*16) return 24;
if(qhead>kv_head*8) return 16;
......@@ -870,17 +872,18 @@ void paged_attention_938(
int reusekv=get_reusekv(num_heads,num_kv_heads);
int headsize=query.size(3);
AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256");
AT_ASSERTM(num_heads<=num_kv_heads*48, "Page Attention qheads*mtp/kvheads must be smaller than 48");
AT_ASSERTM(num_heads<=num_kv_heads*64, "Page Attention qheads*mtp/kvheads must be smaller than 48");
HEADSIZE_SWITCH(headsize,[&]{
Output_Type_SWITCH(out.dtype(),[&]{
Input_Type_SWITCH(scalar_t,query.dtype(),key_cache.dtype(),[&] {
REUSEKV_SWITCH(reusekv,[&] {
BOOL_SWITCH(block_size==64,is_block64,[&]{
constexpr int BLOCK_SIZE = (is_block64?64:128);
// constexpr int HEAD_SIZE=128;
// constexpr int BLOCK_SIZE = (is_block64?64:128);
constexpr int BLOCK_SIZE=64;
// constexpr int HEAD_SIZE=256;
// using scalar_t=uint16_t;
// constexpr bool is_e4m3=true;
// constexpr static int REUSE_KV_TIMES = 4;
// constexpr static int REUSE_KV_TIMES = 64;
// constexpr bool has_abili=false;
// constexpr bool use_mtp=false;
constexpr static int NUM_THREADS = 256;
......
......@@ -26,6 +26,7 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
#if defined(__gfx928__)
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
......@@ -36,6 +37,29 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// }
}
#else
if constexpr (std::is_same_v<Operator, SumOp<float>>) {
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 sum_v = {zero_init ? 0.0f : summary(mi), 0.0f};
for (int ni = 0; ni < size<1>(tensor); ni += 2) {
__float2 vx2 = {tensor(mi, ni), tensor(mi, ni + 1)};
sum_v = __builtin_hcu_pk_add_f32(sum_v, vx2);
}
summary(mi) = sum_v.x + sum_v.y;
}
else {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
// float ori = summary(mi);
summary(mi) = op(summary(mi), tensor(mi, ni));
// wangaq debug
// if (thread0()) {
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// }
}
}
#endif
}
}
......@@ -131,6 +155,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#if defined(__gfx928__)
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
......@@ -141,6 +166,17 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// This macro is set in PyTorch and not FlashAttention
tensor(mi, ni) = custom_exp2f(tensor(mi, ni) * scale - max_scaled);
}
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scalex2 = {scale, scale};
__float2 max_scaledx2 = {-max_scaled, -max_scaled};
for (int ni = 0; ni < size<1>(tensor); ni += 2) {
__float2 vx2 = {tensor(mi, ni), tensor(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_fma_f32(vx2, scalex2, max_scaledx2);
tensor(mi, ni) = custom_exp2f(res.x);
tensor(mi, ni + 1) = custom_exp2f(res.y);
}
#endif
}
}
......@@ -229,8 +265,19 @@ struct Softmax {
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#if defined(__gfx928__)
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scores_scalex2 = {scores_scale, scores_scale};
for (int ni = 0; ni < size<1>(acc_o_rowcol); ni += 2) {
__float2 vx2 = {acc_o_rowcol(mi, ni), acc_o_rowcol(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_mul_f32(vx2, scores_scalex2);
acc_o_rowcol(mi, ni) = res.x;
acc_o_rowcol(mi, ni + 1) = res.y;
}
#endif
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
......@@ -584,8 +631,19 @@ struct Softmax {
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#if defined(__gfx928__)
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scores_scalex2 = {scale, scale};
for (int ni = 0; ni < size<1>(acc_o_rowcol); ni += 2) {
__float2 vx2 = {acc_o_rowcol(mi, ni), acc_o_rowcol(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_mul_f32(vx2, scores_scalex2);
acc_o_rowcol(mi, ni) = res.x;
acc_o_rowcol(mi, ni + 1) = res.y;
}
#endif
}
return lse;
};
......
......@@ -33,16 +33,20 @@ __forceinline__ __device__ void s_nop() {
}
__forceinline__ __device__ void s_barrier() {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier");
__builtin_amdgcn_sched_barrier(0);
}
template<const int COUNT>
__forceinline__ __device__ void s_waitcnt() {
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_waitcnt vmcnt(%0)\n\t"
"s_barrier\n"
:: "B"(COUNT)
:);
__builtin_amdgcn_sched_barrier(0);
}
template<const int COUNT>
......@@ -1392,12 +1396,14 @@ lds_direct_copy(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
} else if constexpr(mma_layout == _64x16) {
constexpr int elements_per_thread = 4;
......@@ -1413,12 +1419,14 @@ lds_direct_copy(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
} else if constexpr(mma_layout == _16x128) {
constexpr int elements_per_thread = 8;
......@@ -1435,12 +1443,14 @@ lds_direct_copy(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
} else if constexpr(mma_layout == _16x192) {
constexpr int elements_per_thread = 8;
......@@ -1457,12 +1467,14 @@ lds_direct_copy(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
constexpr int elements_per_thread_tail = 4;
......@@ -1481,12 +1493,14 @@ lds_direct_copy(
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
} else if constexpr(mma_layout == _16x64_128) {
constexpr int elements_per_thread = 4;
......@@ -1505,12 +1519,14 @@ lds_direct_copy(
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
} else if constexpr(mma_layout == _16x64_64) {
constexpr int elements_per_thread = 4;
......@@ -1529,12 +1545,14 @@ lds_direct_copy(
// if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
} else if constexpr(mma_layout == _16x96) {
constexpr int elements_per_thread = 8;
......@@ -1552,12 +1570,14 @@ lds_direct_copy(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
if (warp_id < 3) {
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
}
} else if constexpr(mma_layout == _16x96_multi_ins) {
......@@ -1575,12 +1595,14 @@ lds_direct_copy(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
constexpr int elements_per_thread_tail = 2;
......@@ -1599,17 +1621,188 @@ lds_direct_copy(
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
}
}
template <
int k_idx,
int K_BUFF_SIZE = 0,
bool Is_even_MN=true,
MMA_LAYOUT mma_layout = _64x32,
int n_idx = 0,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_even_k_dim256(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
const int row_stride,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
if constexpr (Use_cache_swizzle) {
glob_ptr.latter += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
if constexpr(mma_layout == _64x32) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 4 + warp_id;
int col_offset = col * elements_per_thread;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % K_BUFF_SIZE) * mma_k * element_size;
const int offset_s = k_idx * 32 * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
} else if constexpr(mma_layout == _16x256) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*128;
int row = lane / 4;
int col = tidx % 4;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
const int offset_s = (warp_id * 32 + n_idx * 128) * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
}
}
template <
int k_idx,
bool Is_even_MN=true,
MMA_LAYOUT mma_layout = _64x32,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_even_k(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
const int row_stride,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
if constexpr (Use_cache_swizzle) {
glob_ptr.latter += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
if constexpr(mma_layout == _64x32) {
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row * 4 + warp_id;
int col_offset = col * elements_per_thread;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
const int offset_s = k_idx * 32 * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
} else if constexpr(mma_layout == _16x64_64) {
constexpr int elements_per_thread = 4;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*64;
int row = (tidx / 8) % 16;
int col = tidx % 8;
int row_offset = row + k_idx * 16;
int col_offset = col * elements_per_thread;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
const int offset_s = warp_id / 2 * 32 * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
#endif
}
}
#define fp8 unsigned char
__forceinline__ __device__ float fp8e5m2_to_fp32(const fp8& input) {
union uf16{
......@@ -1769,7 +1962,7 @@ lds_direct_copy(int k_slide,
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -1791,7 +1984,7 @@ lds_direct_copy(int k_slide,
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -1813,7 +2006,7 @@ lds_direct_copy(int k_slide,
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -1837,7 +2030,7 @@ lds_direct_copy(int k_slide,
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -1861,7 +2054,7 @@ lds_direct_copy(int k_slide,
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -1885,7 +2078,7 @@ lds_direct_copy(int k_slide,
// if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -1908,7 +2101,7 @@ lds_direct_copy(int k_slide,
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
if (warp_id < 3) {
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -1931,7 +2124,7 @@ lds_direct_copy(int k_slide,
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -1955,7 +2148,7 @@ lds_direct_copy(int k_slide,
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -2021,7 +2214,7 @@ lds_direct_copy(int n_idx, int k_slide,
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -2101,7 +2294,7 @@ lds_direct_copy_fp8(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if defined(__gfx938__)
#if defined(__gfx938__)||defined(__gfx92a__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -2123,7 +2316,7 @@ lds_direct_copy_fp8(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if defined(__gfx938__)
#if defined(__gfx938__)||defined(__gfx92a__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -2144,7 +2337,7 @@ lds_direct_copy_fp8(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if defined(__gfx938__)
#if defined(__gfx938__)||defined(__gfx92a__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -2167,7 +2360,7 @@ lds_direct_copy_fp8(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if defined(__gfx938__)
#if defined(__gfx938__)||defined(__gfx92a__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
......@@ -2244,7 +2437,7 @@ lds_direct_copy_for_vertical_sparse(
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,idxen offen offset:0, lds \n" ::"v"(offset_v),
......@@ -2289,7 +2482,7 @@ lds_direct_copy_for_vertical_sparse(
// int index_v = offset_v;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_slide * mma_k * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,idxen offen offset:0, lds \n" ::"v"(offset_v),
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -20,7 +20,20 @@ void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_
}
if (params.seqused_k != nullptr) {
// Prefix prefill attention
if (!params.is_int8){
if (params.is_e4m3) {
// FP8 prefix prefill
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128 and params.d_value == 128) {
run_fp8_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream);
} else if (params.d == 192 and params.d_value == 128) {
run_fp8_mha_fwd_prefix_prefill_<elem_type, 192, 128>(params, stream);
} else if (params.d == 256 and params.d_value == 256) {
run_fp8_mha_fwd_prefix_prefill_<elem_type, 256, 256>(params, stream);
} else {
assert(false && "FP8 prefix prefill only supports head_dim=128/128, 192/128, or 256/256");
}
});
} else if (!params.is_int8){
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128 and params.d_value == 128) {
run_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream);
......@@ -65,6 +78,13 @@ void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_
else {
// Decoder-only attention
FP16_SWITCH(!params.is_bf16, [&] {
if (params.is_e4m3) {
if (params.d == 128 and params.d_value == 128) {
run_fp8_mha_fwd_<elem_type, 128, 128>(params, stream);
} else {
assert(false && "FP8 forward only supports head_dim=128");
}
} else {
#if defined(HEADDIM_128_ONLY)
run_mha_fwd_<elem_type, 128, 128>(params, stream);
#elif defined(HEADDIM_192_128_ONLY)
......@@ -74,6 +94,7 @@ void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_
run_mha_fwd_<elem_type, kHeadDimQ, kHeadDimV>(params, stream);
});
#endif
}
});
}
#endif
......
#pragma once
#include <block_info.h>
#include "utils.h"
#include "prefetch.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccum, int kBlockM, int kBlockN, int WARP_M, int WARP_N, int K, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o_gfx946(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = 0;
__shared__ Element dO_lds[kBlockM * kBlockN];
__shared__ Element O_lds[kBlockM * kBlockN];
float dP_sum_cur[(kBlockM/16)] = {0.0f};
const int WARP_NUM = (kBlockM)/(WARP_M);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do, seqlen_do_stride);
auto gO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o, seqlen_o_stride);
ElementAccum *dP_sum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_id)
: "v"(warp_id_vec)
:);
union_vec4_f16x2<Element> dO_reg[((WARP_M*kBlockN)/(32*32))*2];
union_vec4_f16x2<Element> O_reg[((WARP_M*kBlockN)/(32*32))*2];
for(int k_loop=0; k_loop<K/kBlockN; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN;
//read 32 * 128
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gdO, do_block_buffer_load_global_offset, dO_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gO, do_block_buffer_load_global_offset, O_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for(int i = 0; i < kBlockN / 32; ++i) {
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(dO_lds + i * 32 * 32), dO_reg[i * 2 + 0].f16, dO_reg[i * 2 + 1].f16, true);
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(O_lds + i * 32 * 32), O_reg[i * 2 + 0].f16, O_reg[i * 2 + 1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 0, 2, 1, 0);
O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
} else {
dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 0, 2, 1, 0);
O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + min_tile_m*16 + (threadIdx.x & 15)) < binfo.actual_seqlen_q) {
dP_sum_cur[min_tile_m] += UpCast<Element,float,false>(dO_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]) * UpCast<Element,float,false>(O_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]);
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m * 16 + (threadIdx.x & 15)] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
\ No newline at end of file
......@@ -24,12 +24,15 @@
#include "static_switch.h"
#include "dot_do_o.h"
#include "dot_do_o_gfx938.h"
#include "dot_do_o_gfx946.h"
#include "prefetch.h"
#include "flash_singleton.h"
#include "flash_attention_dv_dk_bwd.h"
#include "flash_attention_dv_dk_bwd_gfx938.h"
#include "flash_attention_dv_dk_bwd_gfx946.h"
#include "flash_attention_dq_bwd.h"
#include "flash_attention_dq_bwd_gfx938.h"
#include "flash_attention_dq_bwd_gfx946.h"
using std::make_shared;
using std::shared_ptr;
......
#ifdef DEBUGING
#define print_qk(block_id_m, bidb, bidh) {\
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_dp(block_id_m, bidb, bidh) {\
int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#define print_ds(block_id_m, bidb, bidh) {\
int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
} \
} \
} \
}\
}
#endif
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, int STAGES, int USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dq_1colblock_gfx946(Params &params, int bidb, int bidh, int m_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
#if 1//defined(__gfx936__)
const bool Is_store_K = true;
const bool Is_preload_K = true;
const bool Is_preload_V = true;
#else
const bool Is_store_K = false;
const bool Is_preload_K = false;
const bool Is_preload_V = false;
#endif
const int K_prefetch_level = Is_preload_K ? 1 : 0;
const int V_prefetch_level = Is_preload_V ? 1 : 0;
const int Q_prefetch_level = 3;
Element* K_lds = (Element*)&(smem);
Element* Q_lds = (Element*)&(smem);
Element* dO_lds = (Element*)&(smem);
Element* V_lds = (Element*)&(smem) + kBlockN_* K;
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block < 0 || m_block * kBlockM_ >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM_ - params.window_size_left) / kBlockN_);
const int n_block_max = (!Is_causal && !Is_local) ? ceil_div(binfo.actual_seqlen_k, kBlockN_) : std::min(ceil_div(binfo.actual_seqlen_k, kBlockN_), flash::ceil_div((m_block + 1) * kBlockM_ + params.window_size_right, kBlockN_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dq_stride = params.dq_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + m_block * kBlockM_ * seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + (n_block_max - 1) * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM_ * seqlen_o_stride;
const int row_offset_dq = binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh) + m_block * kBlockM_ * seqlen_dq_stride;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + m_block * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
dq_ptr = reinterpret_cast<Element *>(dq_ptr) + row_offset_dq;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? flash::ceil_div(kBlockM_, kBlockN_) : flash::ceil_div(kBlockM_, kBlockN_) + 1);
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> q_reg[(K/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_f16x2<Element> dO_reg[(K_v/kBlockK_)*((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 acc_dq[(K/kBlockK_) * ((WARP_M_/32)*(kBlockK_/32))][4]={0};
float lse[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int lse_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
lse[mi*2 + min_tile_m] = (Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_) ? gLSE[lse_idx] : INFINITY;
}
}
float dP_sum_reg[WARP_M_/16];
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
for(int min_tile_m = 0; min_tile_m<2; min_tile_m++) {
int dP_sum_idx = warp_id*WARP_M_ + mi*32 + (lane_id & 15) + min_tile_m * 16;
dP_sum_reg[mi*2 + min_tile_m] = gdPsum[dP_sum_idx];
}
}
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, Q_lds, q_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
prefetch_to_vgpr_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gdO, dO_lds, dO_reg, (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
if constexpr (Is_preload_V){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
if constexpr (Is_preload_K){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN_), warp_id);
}
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for (int n_block = n_block_max - 1; n_block >= n_block_min ; --n_block) {
union_vec4_f16x2<Element> v_reg[((WARP_N_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_M_/32)*(kBlockN_/32)][4]= {0};
//dP gemm
gemm_tt_kq_gfx938<false, Is_preload_K, Is_even_MN, 3, V_prefetch_level, K_v, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gdO, gV, dO_lds, V_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), dO_reg, v_reg, dp_reg, warp_id, seqlen_do_stride, seqlen_v_stride
);
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> k_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
//qk gemm
gemm_tt_kq_gfx938<Is_store_K, false, Is_even_MN, Q_prefetch_level, K_prefetch_level, K, kBlockM_, kBlockN_, kBlockK_, WARP_N_, WARP_N_, STAGES, Element>(
gQ, gK, Q_lds, K_lds, (binfo.actual_seqlen_q - m_block * kBlockM_), (binfo.actual_seqlen_k - n_block * kBlockN_), q_reg, k_reg, s_reg, warp_id, seqlen_q_stride, seqlen_k_stride
);
*(uint64_t*)&gV -= ((kBlockN_ * seqlen_v_stride) * sizeof(Element));
if (Is_preload_V && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K_v, Element, ElementAccum, Is_even_MN>(gV, 0, V_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 1 : 0)>(s_reg, binfo.actual_seqlen_q - m_block * kBlockM_ - warp_id * 32, binfo.actual_seqlen_k - n_block * kBlockN_, (m_block * kBlockM_ + warp_id * 32) - (n_block * kBlockN_), params.window_size_left, params.window_size_right);
#ifdef DEBUGING
print_qk(m_block, bidb, bidh);
#endif
scale_apply_exp2_bwd_seq_q_major</*scale_max=*/false, WARP_M_, kBlockN_, union_vec4_fp32, ElementAccum>(s_reg, lse, params.scale_softmax_log2);
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh)
#endif
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
// return p * (dp - d);
};
union_vec4_fp32 dS_reg[(WARP_M_/32)*(kBlockN_/32)][4];
#pragma unroll
for (int ni = 0; ni < (kBlockN_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for (int mi = 0; mi < (WARP_M_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = pointwise_mult(
s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dp_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx],
dP_sum_reg[min_tile_m + mi*2]);
// dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
#ifdef DEBUGING
print_ds(m_block, bidb, bidh);
#endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_M_/32)*(kBlockN_/32)*2];
convert_pk_type_gfx938<WARP_M_, kBlockN_, Element>(dS_reg_fp16, dS_reg);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg_gfx946<Is_store_K , false , Is_even_MN, K, kBlockK_, kBlockM_, kBlockN_, kBlockK_, WARP_M_, 2, Element>(gK, gK, K_lds, dS_reg_fp16, acc_dq, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_k_stride);
}
*(uint64_t*)&gK -= ((kBlockN_ * seqlen_k_stride) * sizeof(Element));
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0){
// printf("(binfo.actual_seqlen_k - n_block * kBlockN_) = %d\n", (binfo.actual_seqlen_k - n_block * kBlockN_));
// }
#if 1//defined(__gfx936__)
{
__syncthreads();
if (Is_preload_K && n_block > n_block_min){
prefetch_to_lds_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, 0, K_lds, (binfo.actual_seqlen_k - (n_block - 1) * kBlockN_), warp_id);
}
}
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
dq_ptr = dq_ptr + binfo.q_offset1(params.dq_batch_stride, params.dq_row_stride, bidb) + binfo.q_offset2(params.dq_head_stride,bidh);
auto gdQ = tcp_cache_swizzle_func<K_v, Element>(dq_ptr);
int dq_lane_seq_idx = (lane_id >> 4);
int dq_lane_head_dim_idx = (lane_id & 15);
int dq_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_m_idx=0; warp_m_idx<(WARP_M_/32); warp_m_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dq_lane_head_dim_idx * seqlen_dq_stride + dq_lane_seq_idx * 4;
int s_offset = (min_tile_m * seqlen_dq_stride * 16 + vec_index % 2 * 2 + vec_index / 2 * 16) + (k_tile_idx*32) + ((warp_id*WARP_M_ + warp_m_idx*32) * seqlen_dq_stride) + (k_loop * kBlockK_ + m_block * kBlockM_ * seqlen_dq_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_m*2 + vec_index / 2].f32[vec_index % 2 * 2] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dq[k_loop * ((WARP_M_/32)*(kBlockK_/32)) + (warp_m_idx*(kBlockK_/32) + k_tile_idx)][min_tile_m*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1] * params.scale_softmax_rp_dropout);
if (Is_even_MN || min_tile_m*16 + (warp_id*WARP_M_ + warp_m_idx*32) + m_block * kBlockM_ + dq_lane_head_dim_idx < binfo.actual_seqlen_q){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdQ, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
#endif
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
......@@ -234,7 +234,7 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto pointwise_mult = [](float p, float dp, float d) {
......@@ -295,9 +295,12 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i
//提前读取V到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
//提前读取K到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
//提前读取Q到lds
if constexpr (Is_preload_Q){
......@@ -307,8 +310,8 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i
if constexpr (Is_preload_dO){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
......
#define print_kq(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
+ block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + qk_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
kq_ptr[offset + reg_id *params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int s_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int s_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int s_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + s_warp_n_id*WARP_N_ + s_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + s_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + s_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = s_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
s_ptr[offset + reg_id * params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
ds_ptr[offset + reg_id * params.seqlen_k] = (dS_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
#define print_ds_fp16(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
ds_ptr[offset + reg_id * params.seqlen_k] = UpCast<Element,float,true>(dS_reg_fp16[(m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx)*2 + min_tile_m].f16[min_tile_n*4 + reg_id]);\
} \
} \
} \
} \
} \
} \
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
// #define print_ds_fp16(block_id_m, bidb, bidh) { \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
// int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
// int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
// for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
// for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
// for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
// for(int reg_id=0; reg_id<4; reg_id++) { \
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
// for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
// if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
// ((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
// int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
// ds_ptr[offset + reg_id * 8 * params.seqlen_k] = UpCast<Element,float>(dS_reg_fp16[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f16[reg_id]);\
// } \
// } \
// } \
// } \
// } \
// } \
// } \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// }
#define print_dp(block_id_m, bidb, bidh) { \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
int dp_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
int dp_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
int dp_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_id*WARP_N_ + dp_warp_m_id*WARP_M_*params.seqlen_k; \
for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) {\
for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
for(int reg_id=0; reg_id<4; reg_id++) { \
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
if(((n_block*kBlockN_ + dp_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
((block_id_m*kBlockM_ + dp_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
int offset = dp_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
dp_ptr[offset + reg_id * params.seqlen_k] = (dp_reg[m_block_idx * (WARP_M_/32) *(WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]);\
} \
} \
} \
} \
} \
}\
} \
__builtin_amdgcn_sched_barrier(0);\
__builtin_amdgcn_s_waitcnt(0);\
__syncthreads();\
__builtin_amdgcn_sched_barrier(0);\
}
/*
load q/k:累加方向为主序方向
ps: 在offset传0的情况下,T和R的取值似乎没有影响!?
调用matrix_load_32x32_b16:
R=0: offset in column direction
load Q: T=1: row major
load K: T=0: column major
m_ab=1: 线程数据在主序方向拼接
调用ds_read_matrix_trans_format(和m_ab保持一致):
element:0x2 row:0x2 col:0x1 alt:0x0
load v:累加方向为非主序方向
调用matrix_load_32x32_b16:
R=0: offset in column direction
T=1: row major
m_ab=0: 线程数据在非主序方向拼接
调用ds_read_matrix_format(和m_ab保持一致)
*/
template<class Element, class ElementAccum, bool Is_dropout, bool Is_causal , bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, int kBlockM_, int kBlockN_, int K, int K_v, int kBlockK_, int WARP_M_, int WARP_N_, bool USE_BSHD_LAYOUT, typename Params>
__forceinline__ __device__ void compute_dk_dv_1colblock_gfx946(Params &params, int bidb, int bidh, int n_block
) {
#ifdef DEBUGING
ElementAccum * kq_ptr = static_cast<ElementAccum*>(params.kq_ptr);
ElementAccum * s_ptr = static_cast<ElementAccum*>(params.s_ptr);
ElementAccum * dp_ptr = static_cast<ElementAccum*>(params.dp_ptr);
ElementAccum * ds_ptr = static_cast<ElementAccum*>(params.ds_ptr);
#endif
Element* q_ptr = static_cast<Element*>(params.q_ptr);
Element* k_ptr = static_cast<Element*>(params.k_ptr);
Element* v_ptr = static_cast<Element*>(params.v_ptr);
Element* o_ptr = static_cast<Element*>(params.o_ptr);
Element* p_ptr = static_cast<Element*>(params.p_ptr);
// Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element* dk_ptr = static_cast<Element*>(params.dk_ptr);
Element* dv_ptr = static_cast<Element*>(params.dv_ptr);
Element* do_ptr = static_cast<Element*>(params.do_ptr);
ElementAccum* softmax_lse_ptr = static_cast<ElementAccum*>(params.softmax_lse_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
//flash-attention QK, kBlockN_==WARP_N_;
// static_assert(kBlockM_=WARP_M_,"Error: kBlockM_ not equal WARP_M_!");
const int WARP_NUM = (kBlockM_*kBlockN_)/(WARP_M_*WARP_N_);
const int M_BLOCK_NUM = params.seqlen_q/kBlockM_;
const int N_BLOCK_NUM = params.seqlen_k/kBlockN_;
extern __shared__ Element smem[];
int K_lds_ratio;
// 0表示k不预取;1表示k预取一半到寄存器;2表示一半到寄存器,一半到LDS;3表示全部预取到寄存器
const int K_prefetch_level = 3;
const int STAGES = 2;
const bool Is_store_Q = true;
const bool Is_store_dO = true;
const bool Is_preload_Q = true;
const bool Is_preload_dO = true;
const int dP_dO_prefetch_level = Is_store_dO ? 1 : 0;
const int Q_prefetech_level = Is_preload_Q ? 1 : 0;
if constexpr (K_prefetch_level == 2){
K_lds_ratio = (K / kBlockK_) / 2;
} else {
K_lds_ratio = (K_prefetch_level == 3) ? 0 : STAGES;
}
Element* K_lds = (Element*)&(smem);
Element* dO_lds = K_lds + kBlockN_ * kBlockK_ * K_lds_ratio;
Element* V_lds = K_prefetch_level == 2 ? dO_lds : K_lds;
Element* Q_lds = Is_store_Q ? dO_lds + kBlockM_ * K_v : dO_lds;
#if 0//defined(__gfx936__)
auto pointwise_mult = [](vec2_fp32 p, vec2_fp32 dp, vec2_fp32 d) {
auto d0 = (!Is_dropout || p[0] >= 0 ? dp[0] - d[0] : d[0]);
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return __builtin_hcu_v_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#endif
int tidx = threadIdx.x;
int lane_id = threadIdx.x & 63; //lane id, 0-63
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (n_block * kBlockN_ >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
const int m_block_min = (!Is_causal && !Is_local) ? 0 : std::max(0, (n_block * kBlockN_ - params.window_size_right) / kBlockM_);
const int m_block_max = !Is_local ? ceil_div(binfo.actual_seqlen_q, kBlockM_) : std::min(ceil_div(binfo.actual_seqlen_q, kBlockM_), ceil_div((n_block + 1) * kBlockN_ + params.window_size_left, kBlockM_));
int seqlen_q_stride = params.q_row_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
int seqlen_dk_stride = params.dk_row_stride;
int seqlen_dv_stride = params.dv_row_stride;
// We move K and V to the last block.
const int row_offset_q = binfo.q_offset1(params.q_batch_stride, params.q_row_stride, bidb) + binfo.q_offset2(params.q_head_stride,bidh) + (m_block_max - 1) * kBlockM_* seqlen_q_stride;
const int row_offset_k = binfo.k_offset1(params.k_batch_stride, params.k_row_stride, bidb) + binfo.k_offset2(params.k_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_k_stride;
const int row_offset_v = binfo.k_offset1(params.v_batch_stride, params.v_row_stride, bidb) + binfo.k_offset2(params.v_head_stride,bidh/params.h_h_k_ratio) + n_block * kBlockN_ * seqlen_v_stride;
const int row_offset_dO = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + (m_block_max - 1) * kBlockM_ * seqlen_o_stride;
// const int row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_;
const int row_offset_lse = params.cu_seqlens_q == nullptr ? (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_ : bidh * params.total_q + binfo.sum_s_q + (m_block_max - 1) * kBlockM_;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + (m_block_max - 1) * kBlockM_;
auto gQ = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(q_ptr) + row_offset_q, seqlen_q_stride);
auto gK = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(k_ptr) + row_offset_k, seqlen_k_stride);
auto gV = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(v_ptr) + row_offset_v, seqlen_v_stride);
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_dO, seqlen_do_stride);
Element * gO = reinterpret_cast<Element *>(o_ptr) + row_offset_o;
ElementAccum *gLSE = reinterpret_cast<ElementAccum *>(softmax_lse_ptr) + row_offset_lse;
ElementAccum *gdPsum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
constexpr int m_masking_steps = (!Is_causal && !Is_local)
? 0
: flash::ceil_div(kBlockN_, kBlockM_);
/***************************************************************************************************************************/
// int warp_id =0;
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
union_vec4_f16x2<Element> k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level == 3)? 1 : 2)]; //ds_read mini size is 32*32,2 is seq, 4 is head dim
union_vec4_f16x2<Element> v_reg[(K_v/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2];
//提前读取V到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
//提前读取K到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
//提前读取Q到lds
if constexpr (Is_preload_Q){
prefetch_to_lds_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, 0, Q_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
//提前读取dO到lds
if constexpr (Is_preload_dO){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
union_vec4_fp32 acc_dk[(K/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
for (int m_block = m_block_max - 1; m_block >= m_block_min; --m_block) {
union_vec4_f16x2<Element> q_reg[((WARP_M_*kBlockK_)/(32*32))*2];
//c mini tile is 32*32
union_vec4_fp32 s_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
/*
qk gemm
结果矩阵layout:
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
...
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
*/
gemm_tt_kq_gfx938<Is_store_Q, Is_preload_dO, Is_even_MN, K_prefetch_level, Q_prefetech_level, K, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gK, gQ, K_lds, Q_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), k_reg, q_reg, s_reg, warp_id, seqlen_k_stride, seqlen_q_stride);
/*
lse layout:
4 warp:
32
32
32
32
因为warp在seqlen_k维度,所以不区分warp
每16个thread持有相同的lse,所以需要/4
*/
float lse[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int lse_idx = mi*32 + min_tile_m * 16 + (lane_id >> 4)*4 + vec_idx;
lse[(mi*2 + min_tile_m)*4 + vec_idx] = Is_even_MN || lse_idx < binfo.actual_seqlen_q - m_block * kBlockM_ ? gLSE[lse_idx] : INFINITY;
}
}
}
apply_mask_bwd_gfx938<Is_even_MN, Is_local ? 3 : (Is_causal ? 2 : 0)>(s_reg, binfo.actual_seqlen_k - n_block * kBlockN_ - warp_id * 32, binfo.actual_seqlen_q - m_block * kBlockM_, (n_block * kBlockN_ + warp_id * 32) - m_block * kBlockM_, params.window_size_right, params.window_size_left);
#ifdef DEBUGING
print_kq(m_block, bidb, bidh);
#endif
//do . o后在headdim维度reduce求和,读取方式和lse一样,因为pad了,所以无需边界判断
float dP_sum_reg[kBlockM_/4];
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
const int dPsum_idx = mi*32 + min_tile_m * 16 + (lane_id >> 4)*4 + vec_idx;
dP_sum_reg[(mi*2 + min_tile_m)*4 + vec_idx] = gdPsum[dPsum_idx];
}
}
}
{
scale_apply_exp2_bwd</*scale_max=*/false, kBlockM_, WARP_N_>(s_reg, lse, params.scale_softmax_log2);
}
#ifdef DEBUGING
print_softmax_rescale_o(m_block, bidb, bidh);
#endif
// //TODO:drop
union_vec4_f16x2<Element> p_reg[(kBlockM_/32)*(WARP_N_/32)*2];
// convert_pk_type<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
convert_pk_type_gfx938<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
//QK(seq_q, seq_kv), seq_q is continuous, seq_kv is not continuous
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dv gemm,dO*P
gpu_gemm_B_in_reg_gfx946<Is_preload_dO, Is_store_dO, Is_even_MN, K_v, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gdO, gQ, dO_lds, p_reg, acc_dv, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_do_stride);
}
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
union_vec4_f16x2<Element> dO_reg[((WARP_M_*kBlockK_)/(32*32))*2];
union_vec4_fp32 dp_reg[(WARP_N_/32)*(kBlockM_/32)][4]= {0};
{
// dP gemm dO * V
gemm_tt_kq_gfx938<Is_store_dO, false, Is_even_MN, 3, dP_dO_prefetch_level, K_v, kBlockN_, kBlockM_, kBlockK_, WARP_N_, WARP_M_, STAGES, Element>(
gV, gdO, V_lds, dO_lds, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), v_reg, dO_reg, dp_reg, warp_id, seqlen_v_stride, seqlen_do_stride);
}
#ifdef DEBUGING
print_dp(m_block, bidb, bidh);
#endif
union_vec4_fp32 dS_reg[(WARP_N_/32)*(kBlockM_/32)][4];
#pragma unroll
for (int mi = 0; mi < (kBlockM_/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int ni = 0; ni < (WARP_N_/32); ++ni) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#if 0//defined(__gfx936__)
#pragma unroll
for(int vec_idx=0; vec_idx<2; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
vec2_fp32{gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m], gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m + 8]});
}
#else
#pragma unroll
for(int vec_idx=0; vec_idx<4; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].f32[vec_idx],
dP_sum_reg[min_tile_m*4 + vec_idx]);
}
#endif
}
}
}
}
// #ifdef DEBUGING
// print_ds(m_block, bidb, bidh);
// #endif
union_vec4_f16x2<Element> dS_reg_fp16[(WARP_N_/32)*(kBlockM_/32)*2];
convert_pk_type_gfx938<kBlockM_, WARP_N_, Element>(dS_reg_fp16, dS_reg);
// #ifdef DEBUGING
// print_ds_fp16(m_block, bidb, bidh);
// #endif
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dk gemm, Q*dS
gpu_gemm_B_in_reg_gfx946<Is_store_Q , false, Is_even_MN, K, kBlockK_, kBlockN_, kBlockM_, kBlockK_, WARP_N_, 2, Element>(gQ, gdO, Q_lds, dS_reg_fp16, acc_dk, (binfo.actual_seqlen_k - n_block * kBlockN_), (binfo.actual_seqlen_q - m_block * kBlockM_), warp_id, seqlen_q_stride);
}
gLSE = gLSE + (-int(kBlockM_));
gdPsum = gdPsum - kBlockM_;
*(uint64_t*)&gQ -= ((kBlockM_ * seqlen_q_stride) * sizeof(Element));
*(uint64_t*)&gdO -= ((kBlockM_ * seqlen_do_stride) * sizeof(Element));
{
__syncthreads();
if (Is_preload_Q && m_block > m_block_min){
prefetch_to_lds_gfx938<true, kBlockM_, K, Element, ElementAccum, Is_even_MN>(gQ, 0, Q_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id);
}
// __syncthreads();
if (Is_preload_dO && m_block > m_block_min){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block - 1) * kBlockM_), warp_id);
}
}
}
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
// dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
int dv_lane_seq_idx = (lane_id >> 4);
int dv_lane_head_dim_idx = (lane_id & 15);
int dv_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dv_lane_head_dim_idx * seqlen_dv_stride + dv_lane_seq_idx * 4;
int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index % 2 * 2 + vec_index / 2 * 16) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2]);
v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1]);
if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dv_lane_head_dim_idx < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
#endif
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
auto gdK = tcp_cache_swizzle_func<K_v, Element>(dk_ptr);
int dk_lane_seq_idx = (lane_id >> 4);
int dk_lane_head_dim_idx = (lane_id & 15);
int dk_global_addr_offset=0;
#pragma unroll
for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
#pragma unroll
for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
#pragma unroll
for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int vec_index=0; vec_index<4; vec_index++) {
int v_offset = dk_lane_head_dim_idx * seqlen_dk_stride + dk_lane_seq_idx * 4;
int s_offset = (min_tile_n * seqlen_dk_stride * 16 + vec_index % 2 * 2 + vec_index / 2 * 16) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dk_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dk_stride);
int known_offset = 0;
vec2_Element<Element> v_data;
v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2] * params.scale_softmax_rp_dropout);
v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1] * params.scale_softmax_rp_dropout);
if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dk_lane_head_dim_idx < binfo.actual_seqlen_k){
inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
}
}
}
}
}
}
}
#endif
// #if 1
// {
// // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// int dv_lane_seq_idx = (lane_id >> 4);
// int dv_lane_head_dim_idx = (lane_id & 15);
// int dv_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// int v_offset = dv_lane_head_dim_idx * seqlen_dv_stride + dv_lane_seq_idx * 8;
// int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// int known_offset = 0;
// vec2_Element<Element> v_data;
// v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 0].f32[vec_index]);
// v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
// if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dv_lane_head_dim_idx < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// #endif
// // //test only
// // {
// // // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// // dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// // auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// // int dv_lane_seq_idx = (lane_id >> 4);
// // int dv_lane_head_dim_idx = (lane_id & 15);
// // int dv_global_addr_offset=0;
// // #pragma unroll
// // for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// // #pragma unroll
// // for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// // #pragma unroll
// // for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// // #pragma unroll
// // for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// // #pragma unroll
// // for(int vec_index=0; vec_index<4; vec_index++) {
// // // int v_offset = dv_lane_head_dim_idx * seqlen_dv_stride + dv_lane_seq_idx * 8;
// // // int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// // int v_offset = dv_lane_head_dim_idx * 2 + dv_lane_seq_idx * 4 * seqlen_dv_stride;
// // int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// // int known_offset = 0;
// // vec2_Element<Element> v_data;
// // v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2]);
// // v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1]);
// // inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// // }
// // }
// // }
// // }
// // }
// // }
// {
// // dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
// dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
// auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
// int dk_lane_seq_idx = (lane_id >> 4);
// int dk_lane_head_dim_idx = (lane_id & 15);
// int dk_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// vec2_Element<Element> v_data;
// int v_offset = dk_lane_head_dim_idx * seqlen_dk_stride + dk_lane_seq_idx * 8;
// int s_offset = (min_tile_n * seqlen_dk_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dk_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dk_stride);
// int known_offset = 0;
// v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 0].f32[vec_index] * params.scale_softmax_rp_dropout);
// v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
// if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dk_lane_head_dim_idx < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// {
// // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// int dv_lane_seq_idx = (lane_id >> 4);
// int dv_lane_head_dim_idx = (lane_id & 15);
// int dv_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// int v_offset = dv_lane_head_dim_idx*2 + dv_lane_seq_idx * seqlen_dv_stride;
// int s_offset = (min_tile_n*seqlen_dv_stride*16 + vec_index * 4 * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// int known_offset = 0;
// vec2_Element<Element> v_data;
// v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index]);
// v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
// if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + warp_n_idx*32 + dv_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// {
// // dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
// dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
// auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
// int dk_lane_seq_idx = (lane_id >> 4);
// int dk_lane_head_dim_idx = (lane_id & 15);
// int dk_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// vec2_Element<Element> v_data;
// int v_offset = dk_lane_head_dim_idx*2 + dk_lane_seq_idx * seqlen_dk_stride;
// int s_offset = n_block * kBlockN_ * seqlen_dk_stride + (warp_id*WARP_N_) * seqlen_dk_stride + (min_tile_n*seqlen_dk_stride*16 + vec_index * 4 * seqlen_dk_stride + k_tile_idx*32 + k_loop * kBlockK_ + warp_n_idx*32);
// int known_offset = 0;
// v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index] * params.scale_softmax_rp_dropout);
// v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
// if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + dk_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
}
#undef print_dq
#undef print_softmax_rescale_o
#undef print_ds
#undef print_ds_fp16
#undef print_dp
......@@ -401,13 +401,11 @@ __forceinline__ __device__ void gpu_gemm_B_in_reg_gfx938(
int A_lds_stage_offset = stage_id * BLOCK_K * BLOCK_M;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg[0].f16, A_reg[1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(A_lds + A_lds_stage_offset);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(f16_lds, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(f16_lds, 1024, 2, 1, 0);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(A_lds + A_lds_stage_offset);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 1024, 2, 1, 0);
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
}
} else {
// gfx938 m_ab = 0的gemm想要复用m_ab = 1的LDS数据
......@@ -485,3 +483,175 @@ __forceinline__ __device__ void gpu_gemm_B_in_reg_gfx938(
#endif
#endif
}
// K BLOCK_K BLOCK_N BLOCK_M BLOCK_K WARP_N
template<bool Is_preload_A, bool Is_store_A, bool Is_even_MN, int M/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum = float>
__forceinline__ __device__ void gpu_gemm_B_in_reg_gfx946(
vec4_uint A_ptr,
vec4_uint C_ptr,
Element* A_lds,
union_vec4_f16x2<Element> B_reg[(WARP_M/32)*(BLOCK_K/32)*2],
union_vec4_fp32 C_reg[(M/BLOCK_M)*(WARP_M/32)*(WARP_N/32)][4],
int N/*seq_kv*/,
int K/*seq_q*/,
int warp_id,
int seqlen_A_stride) {
#if 1
const int WARP_NUM = (BLOCK_M*BLOCK_N)/(WARP_M*WARP_N);
const int A_lds_load_num = (BLOCK_M*BLOCK_K) / (4*32);
static_assert(BLOCK_K>=32, "Error: gpu_gemm_B_in_reg gemm BLOCK_K must be equal or greater than 32");
static_assert(BLOCK_N>=WARP_N, "Error: gpu_gemm_B_in_reg gemm BLOCK_N must be equal or greater than WARP_N");
static_assert(BLOCK_M==WARP_M, "Error: gpu_gemm_B_in_reg gemm BLOCK_M must be equal to WARP_M");
union_vec4_f16x2<Element> A_reg[((WARP_M*BLOCK_K)/(32*32))*2];
//c mini tile is 32*32
// vec4_fp32 o[(WARP_M/32)*(WARP_N/32)][4]={0};
// __shared__ Element A_lds[STAGES*BLOCK_N * BLOCK_K];
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int row = lane_id % 4;
int col = lane_id / 4;
int stage_id = 0;
if(STAGES > 1 && (!Is_preload_A)) {
int m_loop = 0;
int A_block_buffer_load_global_offset = m_loop * BLOCK_M;
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
prefetch_to_lds_gfx938<false, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
#if 1
// int lds_offset = row * 8 + col * 32;
for(int m_loop = 1; m_loop<(M/BLOCK_M) + 1; m_loop++) {
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id = stage_id ^ 1;
}
}
if(STAGES == 1) {
m_loop--;
}
if((!Is_preload_A)&& m_loop < (M/BLOCK_M)) {
int A_block_buffer_load_global_offset = m_loop*BLOCK_M;
int A_lds_stage_offset = (stage_id)*BLOCK_K*BLOCK_M;
prefetch_to_lds_gfx938<false, BLOCK_M, BLOCK_K, Element, ElementAccum, Is_even_MN>(A_ptr, A_block_buffer_load_global_offset, A_lds + A_lds_stage_offset, seqlen_A_stride, warp_id);
}
//BM = 32, BK = 32
if(warp_id == 0) {
if(!Is_preload_A){
if(STAGES > 1 && m_loop < (M/BLOCK_M)) {
vmcnt_wait(1);
} else {
vmcnt_wait(0);
}
}
}
if constexpr (STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id --;
} else {
stage_id = stage_id ^ 1;
}
}
//lds -> vgpr use ds_read_m; left matrix
//由于ds_read方式发生了改变,mmac结果矩阵layout变化,存储的时候,offset要进行修改
{
int A_lds_stage_offset = stage_id * BLOCK_K * BLOCK_M;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg[0].f16, A_reg[1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
} else {
A_reg[0].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg[1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
if constexpr (STAGES == 1){
m_loop++;
}
asm volatile("s_setprio 1");
#pragma unroll
for(int n_idx=0; n_idx<(WARP_N/32); n_idx++) {
#pragma unroll
for(int m_idx=0; m_idx<(WARP_M/32); m_idx++) {
#pragma unroll
for(int k_idx=0; k_idx<(BLOCK_K/32); k_idx++) { //BLOCK_K mini size is 32
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for(int min_tile_k=0; min_tile_k<2; min_tile_k++) {
if constexpr (std::is_same<Element,Float8_e4m3_t>::value){
} else {
//A采用ds_read后对应的mmac
C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
//BN = 32, BK = 32
// vec4_Element<Element>{A_reg[min_tile_k].f16[0*2 + min_tile_m], A_reg[min_tile_k].f16[1*2 + min_tile_m], A_reg[min_tile_k].f16[2*2 + min_tile_m], A_reg[min_tile_k].f16[3*2 + min_tile_m]},
B_reg[min_tile_k].f16x4[min_tile_n],
A_reg[min_tile_k].f16x4[min_tile_m],
C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32);
}
}
}
}
}
}
}
// //test only
// for(int min_tile_n = 0; min_tile_n < 2; ++ min_tile_n) {
// for(int min_tile_m = 0; min_tile_m < 2; ++ min_tile_m) {
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[0] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][0]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[1] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][1]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[2] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][2]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[3] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][3]);
// }
// }
asm volatile("s_setprio 0");
if(STAGES > 1) {
if constexpr(Is_preload_A || Is_store_A){
stage_id ++;
} else {
stage_id ^=1;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
} else {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
}
}
#endif
#endif
}
\ No newline at end of file
......@@ -410,13 +410,11 @@ __forceinline__ __device__ void gemm_tt_kq_gfx938(
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg_tmp[0].f16, A_reg_tmp[1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(A_lds + A_lds_stage_offset);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 1024, 2, 1, 0);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(A_lds + A_lds_stage_offset);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 1024, 2, 1, 0);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
}
}
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
......
......@@ -117,13 +117,10 @@ inline __device__ void prefetch_to_vgpr_gfx938(
srsrc[3] = nm_filter << 8; // set only once
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset_stage;
if(trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset_stage, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset_stage, 0);
}
}
for(int m_loop = 0; m_loop < M / 128; ++m_loop) {
......@@ -147,13 +144,10 @@ inline __device__ void prefetch_to_vgpr_gfx938(
}
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
if(n_loop < N / 32) {
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset_stage;
if(trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset_stage, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset_stage, 0);
}
}
......@@ -167,36 +161,20 @@ inline __device__ void prefetch_to_vgpr_gfx938(
if(trans){
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 1024, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 1024, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
}
} else {
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, false);
if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_format_f16(f16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_format_f16(f16_lds, 1024, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
} else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(
lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 =
__builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 =
__builtin_hcu_ds_read_matrix_format_bf16(bf16_lds, 1024, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 0, 2, 1, 0);
reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset, 1024, 2, 1, 0);
}
}
lgkmcnt_wait(0);
......@@ -246,13 +224,11 @@ inline __device__ void prefetch_to_lds_gfx938(
*(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset);
//计算LDS地址,每个warp使用一个32*32;下一个loop重复利用
int lds_offset = (loop_warp * 32 * 32) * ELEMENT_BYTES;
union union_vec4_uint rsrc_bits;
rsrc_bits.v32 = srsrc;
size_t lds_addr_warp = reinterpret_cast<size_t>(lds) + lds_offset;
int lds_load_offset = reinterpret_cast<size_t>(lds) + lds_offset;
if (trans) {
matrix_load_b16_lds_trans_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset, 0);
} else {
matrix_load_b16_lds_builtin<32, 32, 0, 0>(lds_addr_warp, rsrc_bits.i32, 0);
inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset, 0);
}
}
}
......
......@@ -57,6 +57,7 @@ inline __device__ void apply_mask_bwd(union_vec4_fp32 tensor[1][4], int M, int N
}
}
}
//local mask
if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
......@@ -112,21 +113,38 @@ inline __device__ void apply_mask_bwd_gfx938(union_vec4_fp32 tensor[1][4], int M
}
}
}
// //mask左下角
// if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
// for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
// int M_offset = min_tile_m * 16 + lane_m_idx;
// for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
// for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
// int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
// int N_limit = (M_offset + M_minus_N);
// if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){
// tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
// }
// }
// }
// }
// }
//mask左下角
if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
if (mask_type == 2 ) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = min_tile_m * 16 + lane_m_idx;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
int N_limit = (M_offset + M_minus_N);
if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){
if(N_offset < N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
}
}
}
}
}
//local mask
if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
......@@ -327,7 +345,7 @@ inline __device__ void scale_apply_exp2_bwd(DataType0 tensor[(BLOCK_M/32)*(WARP_
auto vec2_scale = vec2_fp32{scale, scale};
auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled};
auto tensor_tmp =
hcu_pk_fma_f32(
__builtin_hcu_pk_fma_f32(
vec2_tensor,
vec2_scale,
vec2_max_scaled);
......
......@@ -75,6 +75,11 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;
// Attention sink values, one scalar per original query head.
// s_aux_type: 0 none, 1 fp32, 2 fp16, 3 bf16.
void * __restrict__ s_aux_ptr;
int s_aux_type;
// For FP8 scaling
float * __restrict__ q_descale_ptr;
float * __restrict__ k_descale_ptr;
......@@ -366,6 +371,8 @@ struct Flash_fwd_mla_reduce_params {
template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_(Flash_fwd_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_fp8_mha_fwd_(Flash_fwd_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_int8_fwd_splitkv_dispatch(Flash_fwd_params &params, hipStream_t stream);
......@@ -386,6 +393,8 @@ template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_prefix_prefill_
template<typename T, int Headdim, int HeaddimV> void run_int8_mha_fwd_prefix_prefill_(Flash_fwd_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_fp8_mha_fwd_prefix_prefill_(Flash_fwd_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_mla_fwd_prefix_prefill_dispatch_(Flash_fwd_mla_params &params, hipStream_t stream);
template<typename T, int Headdim, int HeaddimV> void run_mla_fwd_dispatch(Flash_fwd_mla_params &params, hipStream_t stream);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment