Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
...@@ -5,4 +5,3 @@ ...@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
...@@ -5,4 +5,3 @@ ...@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
...@@ -5,4 +5,3 @@ ...@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
...@@ -5,4 +5,3 @@ ...@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
...@@ -48,7 +48,7 @@ static __device__ inline float to_float(scalar_t in){ ...@@ -48,7 +48,7 @@ static __device__ inline float to_float(scalar_t in){
inline __device__ float uint82float(const uint8_t& input) { inline __device__ float uint82float(const uint8_t& input) {
#if (defined(__gfx938__) ) #if (defined(__gfx938__) ||defined(__gfx92a__))
return __builtin_hcu_cvt_f32_fp8(input,false,0,0); return __builtin_hcu_cvt_f32_fp8(input,false,0,0);
#else #else
const uint32_t w = (uint32_t)input << 24; const uint32_t w = (uint32_t)input << 24;
...@@ -137,11 +137,11 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) { ...@@ -137,11 +137,11 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
#define REUSEKV_SWITCH(reusekv,...) \ #define REUSEKV_SWITCH(reusekv,...) \
[&] { \ [&] { \
if (reusekv==48){ \ if (reusekv==64){ \
constexpr static int REUSE_KV_TIMES = 48; \ constexpr static int REUSE_KV_TIMES = 64; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==36){ \ }else if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 36; \ constexpr static int REUSE_KV_TIMES = 48; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==32){ \ }else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \ constexpr static int REUSE_KV_TIMES = 32; \
...@@ -257,18 +257,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -257,18 +257,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template<bool is_half> template<bool is_half>
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c) inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{ {
#if (defined(__gfx938__) ) if constexpr (is_half){
if constexpr (is_half){reg_c=__builtin_hcu_mmac_f32_16x16x16_f16_lit_lts(reg_a,reg_b,reg_c,false,false);} reg_c=__builtin_hcu_mmac_f32_16x16x16_f16(reg_a,reg_b,reg_c);
else{ }else{
reg_c=__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c,false,false); reg_c=__builtin_hcu_mmac_f32_16x16x16_bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
}
#else
if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);}
else{
reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
} }
#endif
} }
...@@ -390,23 +383,28 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -390,23 +383,28 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
qk_vec[m]={0,0,0,0}; qk_vec[m]={0,0,0,0};
} }
half4x2 k_vec[HEAD_SIZE/32];
__builtin_amdgcn_sched_barrier(0);
#pragma unroll #pragma unroll
for(int i=0;i<HEAD_SIZE/32;i++){ for(int i=0;i<HEAD_SIZE/32;i++){
half4x2 k_vec;
if constexpr(is_fp8){ if constexpr(is_fp8){
uint8x4x2 k_vec_u8=*reinterpret_cast<const uint8x4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8); uint8x4x2 k_vec_u8=*reinterpret_cast<const uint8x4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
scalar_t *p1=(scalar_t*)&k_vec; scalar_t *p1=(scalar_t*)(k_vec+i);
uint8_t *p2=(uint8_t*)&k_vec_u8; uint8_t *p2=(uint8_t*)&k_vec_u8;
for(int ii=0;ii<8;ii++){ for(int ii=0;ii<8;ii++){
p1[ii]=uint82half<scalar_t,is_e4m3>(p2[ii]); p1[ii]=uint82half<scalar_t,is_e4m3>(p2[ii]);
} }
} }
else{ else{
k_vec=*reinterpret_cast<const half4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8); k_vec[i]=*reinterpret_cast<const half4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
}
} }
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i=0;i<HEAD_SIZE/32;i++){
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
builtin_amdgcn_mmac<is_half>(k_vec.data[0],q_vec[m][i].data[0],qk_vec[m]); builtin_amdgcn_mmac<is_half>(k_vec[i].data[0],q_vec[m][i].data[0],qk_vec[m]);
builtin_amdgcn_mmac<is_half>(k_vec.data[1],q_vec[m][i].data[1],qk_vec[m]); builtin_amdgcn_mmac<is_half>(k_vec[i].data[1],q_vec[m][i].data[1],qk_vec[m]);
} }
} }
#pragma unroll #pragma unroll
...@@ -597,7 +595,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -597,7 +595,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
if(partition_idx<num_partitions-1){ if(partition_idx<num_partitions-1||block_idx < num_seq_blocks-1){
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*vecsize*4+rowid*BLOCK_SIZE; int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*vecsize*4+rowid*BLOCK_SIZE;
...@@ -635,13 +633,11 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -635,13 +633,11 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
v_vec=*reinterpret_cast<const half4_vec*>(v_ptr + offset); v_vec=*reinterpret_cast<const half4_vec*>(v_ptr + offset);
} }
//这里的if判断会影响一定的性能,因此只有最后一个patition才判断 //这里的if判断会影响一定的性能,因此只有最后一个patition才判断
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < 4*vecsize; j++) { for (int j = 0; j < 4*vecsize; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
} }
}
for(int ii=0;ii<vecsize;ii++){ for(int ii=0;ii<vecsize;ii++){
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
builtin_amdgcn_mmac<is_half>(v_vec.data[ii],logits_vec[m].data[ii],accs[m][i]); builtin_amdgcn_mmac<is_half>(v_vec.data[ii],logits_vec[m].data[ii],accs[m][i]);
...@@ -756,8 +752,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine( ...@@ -756,8 +752,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
} }
static int get_reusekv(int qhead,int kv_head){ static int get_reusekv(int qhead,int kv_head){
if(qhead>kv_head*36) return 48; if(qhead>kv_head*48) return 64;
if(qhead>kv_head*32) return 36;//glm4.7 mtp 3 if(qhead>kv_head*32) return 48;
if(qhead>kv_head*24) return 32; if(qhead>kv_head*24) return 32;
if(qhead>kv_head*16) return 24; if(qhead>kv_head*16) return 24;
if(qhead>kv_head*8) return 16; if(qhead>kv_head*8) return 16;
...@@ -831,7 +827,7 @@ void paged_attention( ...@@ -831,7 +827,7 @@ void paged_attention(
hipMalloc(&tmp_out_ptr, temp_out_size); // 100m hipMalloc(&tmp_out_ptr, temp_out_size); // 100m
hipMemset(tmp_out_ptr,0,temp_out_size); hipMemset(tmp_out_ptr,0,temp_out_size);
} }
if(device_name=="gfx938"&&(key_cache.dtype()==torch::kFloat8_e5m2||key_cache.dtype()==torch::kFloat8_e4m3fn)){ if((device_name=="gfx938"|| device_name == "gfx92a")&&(key_cache.dtype()==torch::kFloat8_e5m2||key_cache.dtype()==torch::kFloat8_e4m3fn)){
paged_attention_938(out,query,key_cache,value_cache,block_tables,seq_lens,alibi_slopes,q_scale,k_scale,v_scale,max_seq_len,s_aux_,tmp_out_ptr,PARTITION_SIZE); paged_attention_938(out,query,key_cache,value_cache,block_tables,seq_lens,alibi_slopes,q_scale,k_scale,v_scale,max_seq_len,s_aux_,tmp_out_ptr,PARTITION_SIZE);
return; return;
} }
...@@ -876,16 +872,16 @@ void paged_attention( ...@@ -876,16 +872,16 @@ void paged_attention(
grid.x = num_kv_heads; grid.x = num_kv_heads;
grid.y = num_seqs; grid.y = num_seqs;
AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256"); AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256");
AT_ASSERTM(num_heads<=num_kv_heads*48, "Page Attention qheads*mtp/kvheads must be smaller than 48"); AT_ASSERTM(num_heads<=num_kv_heads*64, "Page Attention qheads*mtp/kvheads must be smaller than 48");
HEADSIZE_SWITCH(headsize,[&]{ HEADSIZE_SWITCH(headsize,[&]{
Input_Type_SWITCH(query.dtype(),[&]{ Input_Type_SWITCH(query.dtype(),[&]{
Cache_Type_SWITCH(scalar_t,key_cache.dtype(),[&] { Cache_Type_SWITCH(scalar_t,key_cache.dtype(),[&] {
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
BOOL_SWITCH(block_size==64,is_block64,[&]{ BOOL_SWITCH(block_size==64,is_block64,[&]{
constexpr int BLOCK_SIZE = (is_block64?64:128); // constexpr int BLOCK_SIZE = (is_block64?64:128);
// constexpr int BLOCK_SIZE=128; constexpr int BLOCK_SIZE=64;
// constexpr int HEAD_SIZE=128; // constexpr int HEAD_SIZE=128;
// using scalar_t=_Float16; // using scalar_t=uint16_t;
// using cache_t = scalar_t; // using cache_t = scalar_t;
constexpr bool is_e4m3=false; constexpr bool is_e4m3=false;
// constexpr static int REUSE_KV_TIMES = 4; // constexpr static int REUSE_KV_TIMES = 4;
......
...@@ -58,7 +58,7 @@ static __device__ inline float to_float(scalar_t in){ ...@@ -58,7 +58,7 @@ static __device__ inline float to_float(scalar_t in){
inline __device__ float uint82float(const uint8_t& input) { inline __device__ float uint82float(const uint8_t& input) {
#if (defined(__gfx938__) ) #if (defined(__gfx938__)||defined(__gfx92a__) )
return __builtin_hcu_cvt_f32_fp8(input,false,0,0); return __builtin_hcu_cvt_f32_fp8(input,false,0,0);
#else #else
const uint32_t w = (uint32_t)input << 24; const uint32_t w = (uint32_t)input << 24;
...@@ -106,7 +106,7 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) { ...@@ -106,7 +106,7 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
template <bool is_e4m3> template <bool is_e4m3>
static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) { static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) {
int val=0; int val=0;
#if (defined(__gfx938__) ) #if (defined(__gfx938__) || defined(__gfx92a__))
if constexpr(is_e4m3){ if constexpr(is_e4m3){
val = __builtin_hcu_cvt_pk_fp8_f32(v1,v2,val,false); val = __builtin_hcu_cvt_pk_fp8_f32(v1,v2,val,false);
val = __builtin_hcu_cvt_pk_fp8_f32(v3,v4,val,true); val = __builtin_hcu_cvt_pk_fp8_f32(v3,v4,val,true);
...@@ -122,7 +122,7 @@ static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) { ...@@ -122,7 +122,7 @@ static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) {
template <bool is_e4m3> template <bool is_e4m3>
static __device__ float4_t to_fp32_from_fp8(int val) { static __device__ float4_t to_fp32_from_fp8(int val) {
float4_t ret; float4_t ret;
#if (defined(__gfx938__) ) #if (defined(__gfx938__) || defined(__gfx92a__))
if constexpr(is_e4m3){ if constexpr(is_e4m3){
ret[0] = __builtin_hcu_cvt_f32_fp8(val,false,0,0); ret[0] = __builtin_hcu_cvt_f32_fp8(val,false,0,0);
ret[1] = __builtin_hcu_cvt_f32_fp8(val,false,0,1); ret[1] = __builtin_hcu_cvt_f32_fp8(val,false,0,1);
...@@ -184,11 +184,11 @@ static __device__ float4_t to_fp32_from_fp8(int val) { ...@@ -184,11 +184,11 @@ static __device__ float4_t to_fp32_from_fp8(int val) {
#define REUSEKV_SWITCH(reusekv,...) \ #define REUSEKV_SWITCH(reusekv,...) \
[&] { \ [&] { \
if (reusekv==48){ \ if (reusekv==64){ \
constexpr static int REUSE_KV_TIMES = 48; \ constexpr static int REUSE_KV_TIMES = 64; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==36){ \ }else if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 36; \ constexpr static int REUSE_KV_TIMES = 48; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==32){ \ }else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \ constexpr static int REUSE_KV_TIMES = 32; \
...@@ -303,13 +303,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -303,13 +303,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template<bool is_e4m3> template<bool is_e4m3>
inline __device__ void builtin_amdgcn_mmac(const intx2& reg_a, const intx2& reg_b, float4_t& reg_c) inline __device__ void builtin_amdgcn_mmac(const intx2& reg_a, const intx2& reg_b, float4_t& reg_c)
{ {
#if (defined(__gfx938__) )
if constexpr(is_e4m3){ if constexpr(is_e4m3){
reg_c=__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(reg_a,reg_b,reg_c,false,false); reg_c=__builtin_hcu_mmac_f32_16x16x32_fp8_fp8(reg_a,reg_b,reg_c);
}else{ }else{
reg_c=__builtin_hcu_mmac_f32_16x16x32_bf8_bf8_lit_lts(reg_a,reg_b,reg_c,false,false); reg_c=__builtin_hcu_mmac_f32_16x16x32_bf8_bf8(reg_a,reg_b,reg_c);
} }
#endif
} }
template <typename scalar_t,typename q_type,bool is_e4m3 ,int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t,typename q_type,bool is_e4m3 ,int HEAD_SIZE, int BLOCK_SIZE,
...@@ -332,7 +330,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -332,7 +330,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr, const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr,
int max_num_partitions,int PARTITION_SIZE, int max_num_partitions,int PARTITION_SIZE,
const scalar_t* __restrict__ s_aux_ptr,int mtp,bool has_abili) { // ★ Attention Sinks: [num_heads] scalar_t ★ const scalar_t* __restrict__ s_aux_ptr,int mtp,bool has_abili) { // ★ Attention Sinks: [num_heads] scalar_t ★
#if (defined(__gfx938__) ) #if (defined(__gfx938__)||defined(__gfx92a__) )
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
constexpr int kv_head_stride=BLOCK_SIZE*HEAD_SIZE; constexpr int kv_head_stride=BLOCK_SIZE*HEAD_SIZE;
...@@ -453,10 +451,16 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -453,10 +451,16 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
qk_vec[m]={0,0,0,0}; qk_vec[m]={0,0,0,0};
} }
intx4 k_vec[HEAD_SIZE/64];
__builtin_amdgcn_sched_barrier(0);
#pragma unroll #pragma unroll
for(int i=0;i<HEAD_SIZE/64;i++){ for(int i=0;i<HEAD_SIZE/64;i++){
intx4 k_vec=*reinterpret_cast<const intx4*>(k_ptr+i*64+rowid*HEAD_SIZE+rows*16); k_vec[i]=*reinterpret_cast<const intx4*>(k_ptr+i*64+rowid*HEAD_SIZE+rows*16);
intx2 *k_vec_2 = (intx2*)&k_vec; }
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for(int i=0;i<HEAD_SIZE/64;i++){
intx2 *k_vec_2 = (intx2*)(k_vec+i);
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
intx2 *q_vec_2 = (intx2*)(&q_vec[m][i]); intx2 *q_vec_2 = (intx2*)(&q_vec[m][i]);
builtin_amdgcn_mmac<is_e4m3>(k_vec_2[0],q_vec_2[0],qk_vec[m]); builtin_amdgcn_mmac<is_e4m3>(k_vec_2[0],q_vec_2[0],qk_vec[m]);
...@@ -655,7 +659,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -655,7 +659,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
const uint8_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const uint8_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
if(partition_idx<num_partitions-1){ if(partition_idx<num_partitions-1||block_idx < num_seq_blocks-1){
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE; int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE;
...@@ -673,13 +677,11 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -673,13 +677,11 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE; int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*16+rowid*BLOCK_SIZE;
int_vec v_vec = *reinterpret_cast<const int_vec*>(v_ptr + offset); int_vec v_vec = *reinterpret_cast<const int_vec*>(v_ptr + offset);
//这里的if判断会影响一定的性能,因此只有最后一个patition才判断 //这里的if判断会影响一定的性能,因此只有最后一个patition才判断
if (block_idx == num_seq_blocks - 1) {
uint8_t* v_vec_ptr = reinterpret_cast<uint8_t*>(&v_vec); uint8_t* v_vec_ptr = reinterpret_cast<uint8_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < 16; j++) { for (int j = 0; j < 16; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
} }
}
for(int ii=0;ii<vecsize;ii++){ for(int ii=0;ii<vecsize;ii++){
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
builtin_amdgcn_mmac<is_e4m3>(v_vec.data[ii],logits_vec[m][ii],accs[m][i]); builtin_amdgcn_mmac<is_e4m3>(v_vec.data[ii],logits_vec[m][ii],accs[m][i]);
...@@ -795,8 +797,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine( ...@@ -795,8 +797,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
} }
static int get_reusekv(int qhead,int kv_head){ static int get_reusekv(int qhead,int kv_head){
if(qhead>kv_head*36) return 48; if(qhead>kv_head*48) return 64;
if(qhead>kv_head*32) return 36;//glm4.7 mtp 3 if(qhead>kv_head*32) return 48;
if(qhead>kv_head*24) return 32; if(qhead>kv_head*24) return 32;
if(qhead>kv_head*16) return 24; if(qhead>kv_head*16) return 24;
if(qhead>kv_head*8) return 16; if(qhead>kv_head*8) return 16;
...@@ -870,17 +872,18 @@ void paged_attention_938( ...@@ -870,17 +872,18 @@ void paged_attention_938(
int reusekv=get_reusekv(num_heads,num_kv_heads); int reusekv=get_reusekv(num_heads,num_kv_heads);
int headsize=query.size(3); int headsize=query.size(3);
AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256"); AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256");
AT_ASSERTM(num_heads<=num_kv_heads*48, "Page Attention qheads*mtp/kvheads must be smaller than 48"); AT_ASSERTM(num_heads<=num_kv_heads*64, "Page Attention qheads*mtp/kvheads must be smaller than 48");
HEADSIZE_SWITCH(headsize,[&]{ HEADSIZE_SWITCH(headsize,[&]{
Output_Type_SWITCH(out.dtype(),[&]{ Output_Type_SWITCH(out.dtype(),[&]{
Input_Type_SWITCH(scalar_t,query.dtype(),key_cache.dtype(),[&] { Input_Type_SWITCH(scalar_t,query.dtype(),key_cache.dtype(),[&] {
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
BOOL_SWITCH(block_size==64,is_block64,[&]{ BOOL_SWITCH(block_size==64,is_block64,[&]{
constexpr int BLOCK_SIZE = (is_block64?64:128); // constexpr int BLOCK_SIZE = (is_block64?64:128);
// constexpr int HEAD_SIZE=128; constexpr int BLOCK_SIZE=64;
// constexpr int HEAD_SIZE=256;
// using scalar_t=uint16_t; // using scalar_t=uint16_t;
// constexpr bool is_e4m3=true; // constexpr bool is_e4m3=true;
// constexpr static int REUSE_KV_TIMES = 4; // constexpr static int REUSE_KV_TIMES = 64;
// constexpr bool has_abili=false; // constexpr bool has_abili=false;
// constexpr bool use_mtp=false; // constexpr bool use_mtp=false;
constexpr static int NUM_THREADS = 256; constexpr static int NUM_THREADS = 256;
......
...@@ -26,6 +26,7 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t ...@@ -26,6 +26,7 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) { for (int mi = 0; mi < size<0>(tensor); mi++) {
#if defined(__gfx928__)
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll #pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) { for (int ni = 1; ni < size<1>(tensor); ni++) {
...@@ -36,6 +37,29 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t ...@@ -36,6 +37,29 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi)); // printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// } // }
} }
#else
if constexpr (std::is_same_v<Operator, SumOp<float>>) {
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 sum_v = {zero_init ? 0.0f : summary(mi), 0.0f};
for (int ni = 0; ni < size<1>(tensor); ni += 2) {
__float2 vx2 = {tensor(mi, ni), tensor(mi, ni + 1)};
sum_v = __builtin_hcu_pk_add_f32(sum_v, vx2);
}
summary(mi) = sum_v.x + sum_v.y;
}
else {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
// float ori = summary(mi);
summary(mi) = op(summary(mi), tensor(mi, ni));
// wangaq debug
// if (thread0()) {
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// }
}
}
#endif
} }
} }
...@@ -131,6 +155,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso ...@@ -131,6 +155,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// We don't want (-inf - (-inf)) since that would give NaN. // We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64. // If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#if defined(__gfx928__)
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) { for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
...@@ -141,6 +166,17 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso ...@@ -141,6 +166,17 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// This macro is set in PyTorch and not FlashAttention // This macro is set in PyTorch and not FlashAttention
tensor(mi, ni) = custom_exp2f(tensor(mi, ni) * scale - max_scaled); tensor(mi, ni) = custom_exp2f(tensor(mi, ni) * scale - max_scaled);
} }
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scalex2 = {scale, scale};
__float2 max_scaledx2 = {-max_scaled, -max_scaled};
for (int ni = 0; ni < size<1>(tensor); ni += 2) {
__float2 vx2 = {tensor(mi, ni), tensor(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_fma_f32(vx2, scalex2, max_scaledx2);
tensor(mi, ni) = custom_exp2f(res.x);
tensor(mi, ni + 1) = custom_exp2f(res.y);
}
#endif
} }
} }
...@@ -229,8 +265,19 @@ struct Softmax { ...@@ -229,8 +265,19 @@ struct Softmax {
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); float scores_scale = custom_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale; row_sum(mi) *= scores_scale;
#if defined(__gfx928__)
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scores_scalex2 = {scores_scale, scores_scale};
for (int ni = 0; ni < size<1>(acc_o_rowcol); ni += 2) {
__float2 vx2 = {acc_o_rowcol(mi, ni), acc_o_rowcol(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_mul_f32(vx2, scores_scalex2);
acc_o_rowcol(mi, ni) = res.x;
acc_o_rowcol(mi, ni + 1) = res.y;
}
#endif
} }
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum. // We don't do the reduce across threads here since we don't need to use the row_sum.
...@@ -584,8 +631,19 @@ struct Softmax { ...@@ -584,8 +631,19 @@ struct Softmax {
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#if defined(__gfx928__)
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
#else
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scores_scalex2 = {scale, scale};
for (int ni = 0; ni < size<1>(acc_o_rowcol); ni += 2) {
__float2 vx2 = {acc_o_rowcol(mi, ni), acc_o_rowcol(mi, ni + 1)};
__float2 res = __builtin_hcu_pk_mul_f32(vx2, scores_scalex2);
acc_o_rowcol(mi, ni) = res.x;
acc_o_rowcol(mi, ni + 1) = res.y;
}
#endif
} }
return lse; return lse;
}; };
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -24,12 +24,15 @@ ...@@ -24,12 +24,15 @@
#include "static_switch.h" #include "static_switch.h"
#include "dot_do_o.h" #include "dot_do_o.h"
#include "dot_do_o_gfx938.h" #include "dot_do_o_gfx938.h"
#include "dot_do_o_gfx946.h"
#include "prefetch.h" #include "prefetch.h"
#include "flash_singleton.h" #include "flash_singleton.h"
#include "flash_attention_dv_dk_bwd.h" #include "flash_attention_dv_dk_bwd.h"
#include "flash_attention_dv_dk_bwd_gfx938.h" #include "flash_attention_dv_dk_bwd_gfx938.h"
#include "flash_attention_dv_dk_bwd_gfx946.h"
#include "flash_attention_dq_bwd.h" #include "flash_attention_dq_bwd.h"
#include "flash_attention_dq_bwd_gfx938.h" #include "flash_attention_dq_bwd_gfx938.h"
#include "flash_attention_dq_bwd_gfx946.h"
using std::make_shared; using std::make_shared;
using std::shared_ptr; using std::shared_ptr;
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment