Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
...@@ -5,4 +5,3 @@ ...@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
...@@ -5,4 +5,3 @@ ...@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
...@@ -5,4 +5,3 @@ ...@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
...@@ -5,4 +5,3 @@ ...@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
...@@ -48,7 +48,7 @@ static __device__ inline float to_float(scalar_t in){ ...@@ -48,7 +48,7 @@ static __device__ inline float to_float(scalar_t in){
inline __device__ float uint82float(const uint8_t& input) { inline __device__ float uint82float(const uint8_t& input) {
#if (defined(__gfx938__) ) #if (defined(__gfx938__) ||defined(__gfx92a__))
return __builtin_hcu_cvt_f32_fp8(input,false,0,0); return __builtin_hcu_cvt_f32_fp8(input,false,0,0);
#else #else
const uint32_t w = (uint32_t)input << 24; const uint32_t w = (uint32_t)input << 24;
...@@ -137,11 +137,11 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) { ...@@ -137,11 +137,11 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
#define REUSEKV_SWITCH(reusekv,...) \ #define REUSEKV_SWITCH(reusekv,...) \
[&] { \ [&] { \
if (reusekv==48){ \ if (reusekv==64){ \
constexpr static int REUSE_KV_TIMES = 48; \ constexpr static int REUSE_KV_TIMES = 64; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==36){ \ }else if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 36; \ constexpr static int REUSE_KV_TIMES = 48; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==32){ \ }else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \ constexpr static int REUSE_KV_TIMES = 32; \
...@@ -257,18 +257,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -257,18 +257,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template<bool is_half> template<bool is_half>
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c) inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{ {
#if (defined(__gfx938__) ) if constexpr (is_half){
if constexpr (is_half){reg_c=__builtin_hcu_mmac_f32_16x16x16_f16_lit_lts(reg_a,reg_b,reg_c,false,false);} reg_c=__builtin_hcu_mmac_f32_16x16x16_f16(reg_a,reg_b,reg_c);
else{ }else{
reg_c=__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c,false,false); reg_c=__builtin_hcu_mmac_f32_16x16x16_bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
} }
#else
if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);}
else{
reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
}
#endif
} }
...@@ -390,23 +383,28 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -390,23 +383,28 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
qk_vec[m]={0,0,0,0}; qk_vec[m]={0,0,0,0};
} }
half4x2 k_vec[HEAD_SIZE/32];
__builtin_amdgcn_sched_barrier(0);
#pragma unroll #pragma unroll
for(int i=0;i<HEAD_SIZE/32;i++){ for(int i=0;i<HEAD_SIZE/32;i++){
half4x2 k_vec;
if constexpr(is_fp8){ if constexpr(is_fp8){
uint8x4x2 k_vec_u8=*reinterpret_cast<const uint8x4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8); uint8x4x2 k_vec_u8=*reinterpret_cast<const uint8x4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
scalar_t *p1=(scalar_t*)&k_vec; scalar_t *p1=(scalar_t*)(k_vec+i);
uint8_t *p2=(uint8_t*)&k_vec_u8; uint8_t *p2=(uint8_t*)&k_vec_u8;
for(int ii=0;ii<8;ii++){ for(int ii=0;ii<8;ii++){
p1[ii]=uint82half<scalar_t,is_e4m3>(p2[ii]); p1[ii]=uint82half<scalar_t,is_e4m3>(p2[ii]);
} }
} }
else{ else{
k_vec=*reinterpret_cast<const half4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8); k_vec[i]=*reinterpret_cast<const half4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
} }
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i=0;i<HEAD_SIZE/32;i++){
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
builtin_amdgcn_mmac<is_half>(k_vec.data[0],q_vec[m][i].data[0],qk_vec[m]); builtin_amdgcn_mmac<is_half>(k_vec[i].data[0],q_vec[m][i].data[0],qk_vec[m]);
builtin_amdgcn_mmac<is_half>(k_vec.data[1],q_vec[m][i].data[1],qk_vec[m]); builtin_amdgcn_mmac<is_half>(k_vec[i].data[1],q_vec[m][i].data[1],qk_vec[m]);
} }
} }
#pragma unroll #pragma unroll
...@@ -597,7 +595,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -597,7 +595,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
if(partition_idx<num_partitions-1){ if(partition_idx<num_partitions-1||block_idx < num_seq_blocks-1){
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*vecsize*4+rowid*BLOCK_SIZE; int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*vecsize*4+rowid*BLOCK_SIZE;
...@@ -635,12 +633,10 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -635,12 +633,10 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
v_vec=*reinterpret_cast<const half4_vec*>(v_ptr + offset); v_vec=*reinterpret_cast<const half4_vec*>(v_ptr + offset);
} }
//这里的if判断会影响一定的性能,因此只有最后一个patition才判断 //这里的if判断会影响一定的性能,因此只有最后一个patition才判断
if (block_idx == num_seq_blocks - 1) { scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); #pragma unroll
#pragma unroll for (int j = 0; j < 4*vecsize; j++) {
for (int j = 0; j < 4*vecsize; j++) { v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
}
} }
for(int ii=0;ii<vecsize;ii++){ for(int ii=0;ii<vecsize;ii++){
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
...@@ -756,8 +752,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine( ...@@ -756,8 +752,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
} }
static int get_reusekv(int qhead,int kv_head){ static int get_reusekv(int qhead,int kv_head){
if(qhead>kv_head*36) return 48; if(qhead>kv_head*48) return 64;
if(qhead>kv_head*32) return 36;//glm4.7 mtp 3 if(qhead>kv_head*32) return 48;
if(qhead>kv_head*24) return 32; if(qhead>kv_head*24) return 32;
if(qhead>kv_head*16) return 24; if(qhead>kv_head*16) return 24;
if(qhead>kv_head*8) return 16; if(qhead>kv_head*8) return 16;
...@@ -831,7 +827,7 @@ void paged_attention( ...@@ -831,7 +827,7 @@ void paged_attention(
hipMalloc(&tmp_out_ptr, temp_out_size); // 100m hipMalloc(&tmp_out_ptr, temp_out_size); // 100m
hipMemset(tmp_out_ptr,0,temp_out_size); hipMemset(tmp_out_ptr,0,temp_out_size);
} }
if(device_name=="gfx938"&&(key_cache.dtype()==torch::kFloat8_e5m2||key_cache.dtype()==torch::kFloat8_e4m3fn)){ if((device_name=="gfx938"|| device_name == "gfx92a")&&(key_cache.dtype()==torch::kFloat8_e5m2||key_cache.dtype()==torch::kFloat8_e4m3fn)){
paged_attention_938(out,query,key_cache,value_cache,block_tables,seq_lens,alibi_slopes,q_scale,k_scale,v_scale,max_seq_len,s_aux_,tmp_out_ptr,PARTITION_SIZE); paged_attention_938(out,query,key_cache,value_cache,block_tables,seq_lens,alibi_slopes,q_scale,k_scale,v_scale,max_seq_len,s_aux_,tmp_out_ptr,PARTITION_SIZE);
return; return;
} }
...@@ -876,16 +872,16 @@ void paged_attention( ...@@ -876,16 +872,16 @@ void paged_attention(
grid.x = num_kv_heads; grid.x = num_kv_heads;
grid.y = num_seqs; grid.y = num_seqs;
AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256"); AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256");
AT_ASSERTM(num_heads<=num_kv_heads*48, "Page Attention qheads*mtp/kvheads must be smaller than 48"); AT_ASSERTM(num_heads<=num_kv_heads*64, "Page Attention qheads*mtp/kvheads must be smaller than 48");
HEADSIZE_SWITCH(headsize,[&]{ HEADSIZE_SWITCH(headsize,[&]{
Input_Type_SWITCH(query.dtype(),[&]{ Input_Type_SWITCH(query.dtype(),[&]{
Cache_Type_SWITCH(scalar_t,key_cache.dtype(),[&] { Cache_Type_SWITCH(scalar_t,key_cache.dtype(),[&] {
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
BOOL_SWITCH(block_size==64,is_block64,[&]{ BOOL_SWITCH(block_size==64,is_block64,[&]{
constexpr int BLOCK_SIZE = (is_block64?64:128); // constexpr int BLOCK_SIZE = (is_block64?64:128);
// constexpr int BLOCK_SIZE=128; constexpr int BLOCK_SIZE=64;
// constexpr int HEAD_SIZE=128; // constexpr int HEAD_SIZE=128;
// using scalar_t=_Float16; // using scalar_t=uint16_t;
// using cache_t = scalar_t; // using cache_t = scalar_t;
constexpr bool is_e4m3=false; constexpr bool is_e4m3=false;
// constexpr static int REUSE_KV_TIMES = 4; // constexpr static int REUSE_KV_TIMES = 4;
......
...@@ -58,7 +58,7 @@ static __device__ inline float to_float(scalar_t in){ ...@@ -58,7 +58,7 @@ static __device__ inline float to_float(scalar_t in){
inline __device__ float uint82float(const uint8_t& input) { inline __device__ float uint82float(const uint8_t& input) {
#if (defined(__gfx938__) ) #if (defined(__gfx938__)||defined(__gfx92a__) )
return __builtin_hcu_cvt_f32_fp8(input,false,0,0); return __builtin_hcu_cvt_f32_fp8(input,false,0,0);
#else #else
const uint32_t w = (uint32_t)input << 24; const uint32_t w = (uint32_t)input << 24;
...@@ -106,7 +106,7 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) { ...@@ -106,7 +106,7 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
template <bool is_e4m3> template <bool is_e4m3>
static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) { static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) {
int val=0; int val=0;
#if (defined(__gfx938__) ) #if (defined(__gfx938__) || defined(__gfx92a__))
if constexpr(is_e4m3){ if constexpr(is_e4m3){
val = __builtin_hcu_cvt_pk_fp8_f32(v1,v2,val,false); val = __builtin_hcu_cvt_pk_fp8_f32(v1,v2,val,false);
val = __builtin_hcu_cvt_pk_fp8_f32(v3,v4,val,true); val = __builtin_hcu_cvt_pk_fp8_f32(v3,v4,val,true);
...@@ -122,7 +122,7 @@ static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) { ...@@ -122,7 +122,7 @@ static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) {
template <bool is_e4m3> template <bool is_e4m3>
static __device__ float4_t to_fp32_from_fp8(int val) { static __device__ float4_t to_fp32_from_fp8(int val) {
float4_t ret; float4_t ret;
#if (defined(__gfx938__) ) #if (defined(__gfx938__) || defined(__gfx92a__))
if constexpr(is_e4m3){ if constexpr(is_e4m3){
ret[0] = __builtin_hcu_cvt_f32_fp8(val,false,0,0); ret[0] = __builtin_hcu_cvt_f32_fp8(val,false,0,0);
ret[1] = __builtin_hcu_cvt_f32_fp8(val,false,0,1); ret[1] = __builtin_hcu_cvt_f32_fp8(val,false,0,1);
...@@ -184,11 +184,11 @@ static __device__ float4_t to_fp32_from_fp8(int val) { ...@@ -184,11 +184,11 @@ static __device__ float4_t to_fp32_from_fp8(int val) {
#define REUSEKV_SWITCH(reusekv,...) \ #define REUSEKV_SWITCH(reusekv,...) \
[&] { \ [&] { \
if (reusekv==48){ \ if (reusekv==64){ \
constexpr static int REUSE_KV_TIMES = 48; \ constexpr static int REUSE_KV_TIMES = 64; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==36){ \ }else if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 36; \ constexpr static int REUSE_KV_TIMES = 48; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==32){ \ }else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \ constexpr static int REUSE_KV_TIMES = 32; \
...@@ -303,13 +303,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -303,13 +303,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template<bool is_e4m3> template<bool is_e4m3>
inline __device__ void builtin_amdgcn_mmac(const intx2& reg_a, const intx2& reg_b, float4_t& reg_c) inline __device__ void builtin_amdgcn_mmac(const intx2& reg_a, const intx2& reg_b, float4_t& reg_c)
{ {
#if (defined(__gfx938__) )
if constexpr(is_e4m3){ if constexpr(is_e4m3){
reg_c=__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(reg_a,reg_b,reg_c,false,false); reg_c=__builtin_hcu_mmac_f32_16x16x32_fp8_fp8(reg_a,reg_b,reg_c);
}else{ }else{
reg_c=__builtin_hcu_mmac_f32_16x16x32_bf8_bf8_lit_lts(reg_a,reg_b,reg_c,false,false); reg_c=__builtin_hcu_mmac_f32_16x16x32_bf8_bf8(reg_a,reg_b,reg_c);
} }
#endif
} }
template <typename scalar_t,typename q_type,bool is_e4m3 ,int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t,typename q_type,bool is_e4m3 ,int HEAD_SIZE, int BLOCK_SIZE,
...@@ -332,7 +330,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -332,7 +330,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr, const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr,
int max_num_partitions,int PARTITION_SIZE, int max_num_partitions,int PARTITION_SIZE,
const scalar_t* __restrict__ s_aux_ptr,int mtp,bool has_abili) { // ★ Attention Sinks: [num_heads] scalar_t ★ const scalar_t* __restrict__ s_aux_ptr,int mtp,bool has_abili) { // ★ Attention Sinks: [num_heads] scalar_t ★
#if (defined(__gfx938__) ) #if (defined(__gfx938__)||defined(__gfx92a__) )
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
constexpr int kv_head_stride=BLOCK_SIZE*HEAD_SIZE; constexpr int kv_head_stride=BLOCK_SIZE*HEAD_SIZE;
...@@ -453,10 +451,16 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -453,10 +451,16 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
qk_vec[m]={0,0,0,0}; qk_vec[m]={0,0,0,0};
} }
intx4 k_vec[HEAD_SIZE/64];
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i=0;i<HEAD_SIZE/64;i++){
k_vec[i]=*reinterpret_cast<const intx4*>(k_ptr+i*64+rowid*HEAD_SIZE+rows*16);
}
__builtin_amdgcn_sched_barrier(0);
#pragma unroll #pragma unroll
for(int i=0;i<HEAD_SIZE/64;i++){ for(int i=0;i<HEAD_SIZE/64;i++){
intx4 k_vec=*reinterpret_cast<const intx4*>(k_ptr+i*64+rowid*HEAD_SIZE+rows*16); intx2 *k_vec_2 = (intx2*)(k_vec+i);
intx2 *k_vec_2 = (intx2*)&k_vec;
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
intx2 *q_vec_2 = (intx2*)(&q_vec[m][i]); intx2 *q_vec_2 = (intx2*)(&q_vec[m][i]);
builtin_amdgcn_mmac<is_e4m3>(k_vec_2[0],q_vec_2[0],qk_vec[m]); builtin_amdgcn_mmac<is_e4m3>(k_vec_2[0],q_vec_2[0],qk_vec[m]);
...@@ -655,7 +659,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -655,7 +659,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
const uint8_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const uint8_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
if(partition_idx<num_partitions-1){ if(partition_idx<num_partitions-1||block_idx < num_seq_blocks-1){
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE; int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE;
...@@ -673,12 +677,10 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -673,12 +677,10 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE; int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE;
int_vec v_vec = *reinterpret_cast<const int_vec*>(v_ptr + offset); int_vec v_vec = *reinterpret_cast<const int_vec*>(v_ptr + offset);
//这里的if判断会影响一定的性能,因此只有最后一个patition才判断 //这里的if判断会影响一定的性能,因此只有最后一个patition才判断
if (block_idx == num_seq_blocks - 1) { uint8_t* v_vec_ptr = reinterpret_cast<uint8_t*>(&v_vec);
uint8_t* v_vec_ptr = reinterpret_cast<uint8_t*>(&v_vec); #pragma unroll
#pragma unroll for (int j = 0; j < 16; j++) {
for (int j = 0; j < 16; j++) { v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
}
} }
for(int ii=0;ii<vecsize;ii++){ for(int ii=0;ii<vecsize;ii++){
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
...@@ -795,8 +797,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine( ...@@ -795,8 +797,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
} }
static int get_reusekv(int qhead,int kv_head){ static int get_reusekv(int qhead,int kv_head){
if(qhead>kv_head*36) return 48; if(qhead>kv_head*48) return 64;
if(qhead>kv_head*32) return 36;//glm4.7 mtp 3 if(qhead>kv_head*32) return 48;
if(qhead>kv_head*24) return 32; if(qhead>kv_head*24) return 32;
if(qhead>kv_head*16) return 24; if(qhead>kv_head*16) return 24;
if(qhead>kv_head*8) return 16; if(qhead>kv_head*8) return 16;
...@@ -870,17 +872,18 @@ void paged_attention_938( ...@@ -870,17 +872,18 @@ void paged_attention_938(
int reusekv=get_reusekv(num_heads,num_kv_heads); int reusekv=get_reusekv(num_heads,num_kv_heads);
int headsize=query.size(3); int headsize=query.size(3);
AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256"); AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256");
AT_ASSERTM(num_heads<=num_kv_heads*48, "Page Attention qheads*mtp/kvheads must be smaller than 48"); AT_ASSERTM(num_heads<=num_kv_heads*64, "Page Attention qheads*mtp/kvheads must be smaller than 48");
HEADSIZE_SWITCH(headsize,[&]{ HEADSIZE_SWITCH(headsize,[&]{
Output_Type_SWITCH(out.dtype(),[&]{ Output_Type_SWITCH(out.dtype(),[&]{
Input_Type_SWITCH(scalar_t,query.dtype(),key_cache.dtype(),[&] { Input_Type_SWITCH(scalar_t,query.dtype(),key_cache.dtype(),[&] {
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
BOOL_SWITCH(block_size==64,is_block64,[&]{ BOOL_SWITCH(block_size==64,is_block64,[&]{
constexpr int BLOCK_SIZE = (is_block64?64:128); // constexpr int BLOCK_SIZE = (is_block64?64:128);
// constexpr int HEAD_SIZE=128; constexpr int BLOCK_SIZE=64;
// constexpr int HEAD_SIZE=256;
// using scalar_t=uint16_t; // using scalar_t=uint16_t;
// constexpr bool is_e4m3=true; // constexpr bool is_e4m3=true;
// constexpr static int REUSE_KV_TIMES = 4; // constexpr static int REUSE_KV_TIMES = 64;
// constexpr bool has_abili=false; // constexpr bool has_abili=false;
// constexpr bool use_mtp=false; // constexpr bool use_mtp=false;
constexpr static int NUM_THREADS = 256; constexpr static int NUM_THREADS = 256;
......
...@@ -26,6 +26,7 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t ...@@ -26,6 +26,7 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) { for (int mi = 0; mi < size<0>(tensor); mi++) {
#if defined(__gfx928__)
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll #pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) { for (int ni = 1; ni < size<1>(tensor); ni++) {
...@@ -36,6 +37,29 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t ...@@ -36,6 +37,29 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi)); // printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// } // }
} }
#else
if constexpr (std::is_same_v<Operator, SumOp<float>>) {
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 sum_v = {zero_init ? 0.0f : summary(mi), 0.0f};
for (int ni = 0; ni < size<1>(tensor); ni += 2) {
__float2 vx2 = {tensor(mi, ni), tensor(mi, ni + 1)};
sum_v = __builtin_hcu_pk_add_f32(sum_v, vx2);
}
summary(mi) = sum_v.x + sum_v.y;
}
else {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
// float ori = summary(mi);
summary(mi) = op(summary(mi), tensor(mi, ni));
// wangaq debug
// if (thread0()) {
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// }
}
}
#endif
} }
} }
...@@ -131,6 +155,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso ...@@ -131,6 +155,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// We don't want (-inf - (-inf)) since that would give NaN. // We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64. // If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#if defined(__gfx928__)
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) { for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
...@@ -141,6 +166,17 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso ...@@ -141,6 +166,17 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// This macro is set in PyTorch and not FlashAttention // This macro is set in PyTorch and not FlashAttention
tensor(mi, ni) = custom_exp2f(tensor(mi, ni) * scale - max_scaled); tensor(mi, ni) = custom_exp2f(tensor(mi, ni) * scale - max_scaled);
} }
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scalex2 = {scale, scale};
__float2 max_scaledx2 = {-max_scaled, -max_scaled};
for (int ni = 0; ni < size<1>(tensor); ni += 2) {
__float2 vx2 = {tensor(mi, ni), tensor(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_fma_f32(vx2, scalex2, max_scaledx2);
tensor(mi, ni) = custom_exp2f(res.x);
tensor(mi, ni + 1) = custom_exp2f(res.y);
}
#endif
} }
} }
...@@ -229,8 +265,19 @@ struct Softmax { ...@@ -229,8 +265,19 @@ struct Softmax {
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale; row_sum(mi) *= scores_scale;
#if defined(__gfx928__)
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scores_scalex2 = {scores_scale, scores_scale};
for (int ni = 0; ni < size<1>(acc_o_rowcol); ni += 2) {
__float2 vx2 = {acc_o_rowcol(mi, ni), acc_o_rowcol(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_mul_f32(vx2, scores_scalex2);
acc_o_rowcol(mi, ni) = res.x;
acc_o_rowcol(mi, ni + 1) = res.y;
}
#endif
} }
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum. // We don't do the reduce across threads here since we don't need to use the row_sum.
...@@ -584,8 +631,19 @@ struct Softmax { ...@@ -584,8 +631,19 @@ struct Softmax {
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#if defined(__gfx928__)
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scores_scalex2 = {scale, scale};
for (int ni = 0; ni < size<1>(acc_o_rowcol); ni += 2) {
__float2 vx2 = {acc_o_rowcol(mi, ni), acc_o_rowcol(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_mul_f32(vx2, scores_scalex2);
acc_o_rowcol(mi, ni) = res.x;
acc_o_rowcol(mi, ni + 1) = res.y;
}
#endif
} }
return lse; return lse;
}; };
......
This diff is collapsed.
This diff is collapsed.
...@@ -20,7 +20,20 @@ void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_ ...@@ -20,7 +20,20 @@ void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_
} }
if (params.seqused_k != nullptr) { if (params.seqused_k != nullptr) {
// Prefix prefill attention // Prefix prefill attention
if (!params.is_int8){ if (params.is_e4m3) {
// FP8 prefix prefill
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128 and params.d_value == 128) {
run_fp8_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream);
} else if (params.d == 192 and params.d_value == 128) {
run_fp8_mha_fwd_prefix_prefill_<elem_type, 192, 128>(params, stream);
} else if (params.d == 256 and params.d_value == 256) {
run_fp8_mha_fwd_prefix_prefill_<elem_type, 256, 256>(params, stream);
} else {
assert(false && "FP8 prefix prefill only supports head_dim=128/128, 192/128, or 256/256");
}
});
} else if (!params.is_int8){
FP16_SWITCH(!params.is_bf16, [&] { FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128 and params.d_value == 128) { if (params.d == 128 and params.d_value == 128) {
run_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream); run_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream);
...@@ -65,15 +78,23 @@ void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_ ...@@ -65,15 +78,23 @@ void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_
else { else {
// Decoder-only attention // Decoder-only attention
FP16_SWITCH(!params.is_bf16, [&] { FP16_SWITCH(!params.is_bf16, [&] {
#if defined(HEADDIM_128_ONLY) if (params.is_e4m3) {
run_mha_fwd_<elem_type, 128, 128>(params, stream); if (params.d == 128 and params.d_value == 128) {
#elif defined(HEADDIM_192_128_ONLY) run_fp8_mha_fwd_<elem_type, 128, 128>(params, stream);
run_mha_fwd_<elem_type, 192, 128>(params, stream); } else {
#else assert(false && "FP8 forward only supports head_dim=128");
ALL_HEADDIM_SWITCH(params.d, params.d_value, [&] { }
run_mha_fwd_<elem_type, kHeadDimQ, kHeadDimV>(params, stream); } else {
}); #if defined(HEADDIM_128_ONLY)
#endif run_mha_fwd_<elem_type, 128, 128>(params, stream);
#elif defined(HEADDIM_192_128_ONLY)
run_mha_fwd_<elem_type, 192, 128>(params, stream);
#else
ALL_HEADDIM_SWITCH(params.d, params.d_value, [&] {
run_mha_fwd_<elem_type, kHeadDimQ, kHeadDimV>(params, stream);
});
#endif
}
}); });
} }
#endif #endif
...@@ -182,4 +203,4 @@ void run_fwd_prefix_prefill_mla(Flash_fwd_mla_params &params, hipStream_t stream ...@@ -182,4 +203,4 @@ void run_fwd_prefix_prefill_mla(Flash_fwd_mla_params &params, hipStream_t stream
run_mla_fwd_prefix_prefill_dispatch_<elem_type, 576, 512>(params, stream); run_mla_fwd_prefix_prefill_dispatch_<elem_type, 576, 512>(params, stream);
}); });
#endif #endif
} }
\ No newline at end of file
#pragma once
#include <block_info.h>
#include "utils.h"
#include "prefetch.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccum, int kBlockM, int kBlockN, int WARP_M, int WARP_N, int K, int STAGES, bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o_gfx946(const Params &params) {
Element *do_ptr = static_cast<Element*>(params.do_ptr);
Element *o_ptr = static_cast<Element*>(params.o_ptr);
ElementAccum* dsoftmax_sum = static_cast<ElementAccum*>(params.dsoftmax_sum);
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
//wave size should be defined in launch file. Here use 64 threads
int lane_id = threadIdx.x & 63; //lane id, 0-63
int warp_id_vec = threadIdx.x / 64; //warp id in a block
int warp_id = 0;
__shared__ Element dO_lds[kBlockM * kBlockN];
__shared__ Element O_lds[kBlockM * kBlockN];
float dP_sum_cur[(kBlockM/16)] = {0.0f};
const int WARP_NUM = (kBlockM)/(WARP_M);
const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
int seqlen_do_stride = params.do_row_stride;
int seqlen_o_stride = params.o_row_stride;
const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb) + binfo.q_offset2(params.do_head_stride,bidh) + m_block * kBlockM * seqlen_do_stride;
const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb) + binfo.q_offset2(params.o_head_stride,bidh) + m_block * kBlockM * seqlen_o_stride;
const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
auto gdO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(do_ptr) + row_offset_do, seqlen_do_stride);
auto gO = prepare_for_matrix_load_gfx938<Element>(reinterpret_cast<Element *>(o_ptr) + row_offset_o, seqlen_o_stride);
ElementAccum *dP_sum = reinterpret_cast<ElementAccum *>(dsoftmax_sum) + row_offset_dpsum;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_id)
: "v"(warp_id_vec)
:);
union_vec4_f16x2<Element> dO_reg[((WARP_M*kBlockN)/(32*32))*2];
union_vec4_f16x2<Element> O_reg[((WARP_M*kBlockN)/(32*32))*2];
for(int k_loop=0; k_loop<K/kBlockN; k_loop++) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
int do_block_buffer_load_global_offset = k_loop * kBlockN;
//read 32 * 128
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gdO, do_block_buffer_load_global_offset, dO_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(gO, do_block_buffer_load_global_offset, O_lds, binfo.actual_seqlen_q - m_block * kBlockM, warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
for(int i = 0; i < kBlockN / 32; ++i) {
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(dO_lds + i * 32 * 32), dO_reg[i * 2 + 0].f16, dO_reg[i * 2 + 1].f16, true);
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(O_lds + i * 32 * 32), O_reg[i * 2 + 0].f16, O_reg[i * 2 + 1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) {
dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 0, 2, 1, 0);
O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
} else {
dO_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
dO_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
O_reg[i*2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 0, 2, 1, 0);
O_reg[i*2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
}
}
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
#pragma unroll
for (int head_dim_idx = 0; head_dim_idx < (kBlockN/32); ++head_dim_idx) {
#pragma unroll
for(int vec_id = 0; vec_id<4; vec_id++) {
#pragma unroll
for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
if (Is_even_MN || (m_block * kBlockM + min_tile_m*16 + (threadIdx.x & 15)) < binfo.actual_seqlen_q) {
dP_sum_cur[min_tile_m] += UpCast<Element,float,false>(dO_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]) * UpCast<Element,float,false>(O_reg[head_dim_idx*2 + min_tile_m].f16[vec_id * 2 + min_tile_n]);
}
}
}
}
}
}
#pragma unroll
for (int mi = 0; mi < (WARP_M/32); ++mi) {
#pragma unroll
for(int min_tile_m=0; min_tile_m<2; min_tile_m++) {
flash::SumOp<float> sum_op;
dP_sum_cur[mi*2 + min_tile_m] = flash::Allreduce<64>::run(dP_sum_cur[mi*2 + min_tile_m], sum_op) * params.p_dropout;
if ((threadIdx.x >> 4) == 0) {
dP_sum[mi*32 + min_tile_m * 16 + (threadIdx.x & 15)] = dP_sum_cur[mi*2 + min_tile_m];
}
}
}
}
\ No newline at end of file
...@@ -24,12 +24,15 @@ ...@@ -24,12 +24,15 @@
#include "static_switch.h" #include "static_switch.h"
#include "dot_do_o.h" #include "dot_do_o.h"
#include "dot_do_o_gfx938.h" #include "dot_do_o_gfx938.h"
#include "dot_do_o_gfx946.h"
#include "prefetch.h" #include "prefetch.h"
#include "flash_singleton.h" #include "flash_singleton.h"
#include "flash_attention_dv_dk_bwd.h" #include "flash_attention_dv_dk_bwd.h"
#include "flash_attention_dv_dk_bwd_gfx938.h" #include "flash_attention_dv_dk_bwd_gfx938.h"
#include "flash_attention_dv_dk_bwd_gfx946.h"
#include "flash_attention_dq_bwd.h" #include "flash_attention_dq_bwd.h"
#include "flash_attention_dq_bwd_gfx938.h" #include "flash_attention_dq_bwd_gfx938.h"
#include "flash_attention_dq_bwd_gfx946.h"
using std::make_shared; using std::make_shared;
using std::shared_ptr; using std::shared_ptr;
......
This diff is collapsed.
...@@ -234,7 +234,7 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i ...@@ -234,7 +234,7 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]); auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1}; // return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1}); // return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return hcu_pk_mul_f32(p, vec2_fp32{d0, d1}); return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
}; };
#else #else
auto pointwise_mult = [](float p, float dp, float d) { auto pointwise_mult = [](float p, float dp, float d) {
...@@ -295,9 +295,12 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i ...@@ -295,9 +295,12 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i
//提前读取V到vgpr //提前读取V到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id); prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gV, V_lds, v_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
//提前读取K到vgpr //提前读取K到vgpr
prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id); prefetch_to_vgpr_gfx938<true, kBlockN_, K, Element, ElementAccum, Is_even_MN>(gK, K_lds, k_reg, (binfo.actual_seqlen_k - n_block * kBlockN_), warp_id);
__builtin_amdgcn_s_waitcnt(0);
__syncthreads();
//提前读取Q到lds //提前读取Q到lds
if constexpr (Is_preload_Q){ if constexpr (Is_preload_Q){
...@@ -307,8 +310,8 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i ...@@ -307,8 +310,8 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params &params, i
if constexpr (Is_preload_dO){ if constexpr (Is_preload_dO){
prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id); prefetch_to_lds_gfx938<true, kBlockM_, K_v, Element, ElementAccum, Is_even_MN>(gdO, 0, dO_lds, (binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM_), warp_id);
} }
// __builtin_amdgcn_s_waitcnt(0); __builtin_amdgcn_s_waitcnt(0);
// __syncthreads(); __syncthreads();
union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0}; union_vec4_fp32 acc_dv[(K_v/kBlockK_) * ((WARP_N_/32)*(kBlockK_/32))][4]={0};
......
...@@ -410,13 +410,11 @@ __forceinline__ __device__ void gemm_tt_kq_gfx938( ...@@ -410,13 +410,11 @@ __forceinline__ __device__ void gemm_tt_kq_gfx938(
int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K; int A_lds_stage_offset = stage_id * BLOCK_M * BLOCK_K;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg_tmp[0].f16, A_reg_tmp[1].f16, true); // DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg_tmp[0].f16, A_reg_tmp[1].f16, true);
if constexpr (std::is_same_v<Element, half_t>) { if constexpr (std::is_same_v<Element, half_t>) {
auto *const f16_lds = hcu_ds_read_matrix_f16_lds_base(A_lds + A_lds_stage_offset); A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 0, 2, 1, 0); A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(f16_lds, 1024, 2, 1, 0);
} else { } else {
auto *const bf16_lds = hcu_ds_read_matrix_bf16_lds_base(A_lds + A_lds_stage_offset); A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(A_lds + A_lds_stage_offset, 0, 2, 1, 0);
A_reg_tmp[0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 0, 2, 1, 0); A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(A_lds + A_lds_stage_offset, 1024, 2, 1, 0);
A_reg_tmp[1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(bf16_lds, 1024, 2, 1, 0);
} }
} }
int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K; int B_lds_stage_offset = stage_id * WARP_N * BLOCK_K;
......
This diff is collapsed.
...@@ -57,6 +57,7 @@ inline __device__ void apply_mask_bwd(union_vec4_fp32 tensor[1][4], int M, int N ...@@ -57,6 +57,7 @@ inline __device__ void apply_mask_bwd(union_vec4_fp32 tensor[1][4], int M, int N
} }
} }
} }
//local mask //local mask
if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128)) if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) { for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
...@@ -112,21 +113,38 @@ inline __device__ void apply_mask_bwd_gfx938(union_vec4_fp32 tensor[1][4], int M ...@@ -112,21 +113,38 @@ inline __device__ void apply_mask_bwd_gfx938(union_vec4_fp32 tensor[1][4], int M
} }
} }
} }
// //mask左下角
// if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
// for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
// int M_offset = min_tile_m * 16 + lane_m_idx;
// for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
// for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
// int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
// int N_limit = (M_offset + M_minus_N);
// if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){
// tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
// }
// }
// }
// }
// }
//mask左下角 //mask左下角
if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) { if (mask_type == 2 ) {
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) { for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
int M_offset = min_tile_m * 16 + lane_m_idx; int M_offset = min_tile_m * 16 + lane_m_idx;
for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) { for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
for(int vec_idx = 0; vec_idx < 4; vec_idx ++) { for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx; int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
int N_limit = (M_offset + M_minus_N); int N_limit = (M_offset + M_minus_N);
if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){ if(N_offset < N_limit){
tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY; tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
} }
} }
} }
} }
} }
//local mask //local mask
if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128)) if (mask_type == 3) {// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) { for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
...@@ -327,7 +345,7 @@ inline __device__ void scale_apply_exp2_bwd(DataType0 tensor[(BLOCK_M/32)*(WARP_ ...@@ -327,7 +345,7 @@ inline __device__ void scale_apply_exp2_bwd(DataType0 tensor[(BLOCK_M/32)*(WARP_
auto vec2_scale = vec2_fp32{scale, scale}; auto vec2_scale = vec2_fp32{scale, scale};
auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled}; auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled};
auto tensor_tmp = auto tensor_tmp =
hcu_pk_fma_f32( __builtin_hcu_pk_fma_f32(
vec2_tensor, vec2_tensor,
vec2_scale, vec2_scale,
vec2_max_scaled); vec2_max_scaled);
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment