// =============================================================================
// Paged-attention kernels (HIP) and their PyTorch host entry point.
//
// NOTE(review): this listing looks damaged by extraction. The original line
// breaks are gone (so every `//` comment now swallows the code that follows it
// on the same physical line) and text that stood between `<` and `>` appears to
// be missing: #include targets, template parameter lists, cast target types,
// std::is_same arguments, the <<<...>>> launch configurations and several loop
// conditions/bodies. The code below is kept byte-for-byte; only comments were
// added or translated. Restore the file from version control before building
// it or changing any logic.
//
// Contents, in order of appearance
// --------------------------------
// * WARP_SIZE (64), MAX/MIN, CHECK_DEVICE / CHECK_SHAPE / CHECK_CONTIGUOUS
//   (TORCH_CHECK wrappers), DIVIDE_ROUND_UP, and the scratch-buffer offsets
//   max_tmp_offset, signal_tmp_offset, streamk_max_block, out_tmp_offset.
// * from_float / to_float: for one group of types (names stripped) the value is
//   assigned directly; otherwise from_float takes the float's bit pattern, adds
//   0x8000 and keeps the upper 16 bits (NaN is not special-cased), and to_float
//   shifts the 16-bit value left by 16 and reinterprets it as float.
//   Presumably the 16-bit case is bfloat16 carried in uint16_t -- TODO confirm.
// * uint82float: on __gfx938__ / __gfx92a__ calls __builtin_hcu_cvt_f32_fp8;
//   otherwise rebuilds a float from the byte with shifts, __clz and a 0x78
//   exponent offset.
// * uint82half: when !is_e4m3 the byte is placed in the upper 8 bits of a
//   16-bit pattern read as _Float16; otherwise the value goes through
//   uint82float. The result is either _Float16 or the upper 16 bits of the
//   float pattern, depending on scalar_t.
// * Dispatch macros that turn runtime values into constants/types inside an
//   immediately-invoked lambda: BOOL_SWITCH, Input_Type_SWITCH (Half ->
//   _Float16, anything else -> uint16_t), Cache_Type_SWITCH (kFloat8_e5m2 and
//   kFloat8_e4m3fn -> uint8_t cache, else cache_t = scalar_t), REUSEKV_SWITCH
//   (64/48/32/24/16/8, default 4), HEADSIZE_SWITCH (64/128/192, default 256).
// * get_device_name: returns props.gcnArchName up to the first ':'; returns an
//   empty string if hipGetDevice or hipGetDeviceProperties fails. The result is
//   cached in the file-scope constant device_name.
// * get_env_: atoi() of an environment variable, 0 when it is unset. The PA_*
//   constants are read once during static initialisation.
// * Vector type aliases (uint8x4_t, half4_t, v4bh, float4_t, float2_t) and the
//   half4vec / uint8x4vec wrappers.
// * block_sum: sums `sum` across a warp with __shfl_xor, stores one partial per
//   warp in red_smem, then combines the first NUM_WARPS partials and returns
//   lane 0's value.
// * builtin_amdgcn_mmac: picks the f16 or bf16 16x16x16 mmac builtin based on
//   is_half.
// * paged_attention_kernel: blockIdx.x selects the KV head, blockIdx.y the
//   sequence and blockIdx.z the partition; a block returns immediately when its
//   partition index is >= ceil(seq_len / PARTITION_SIZE). The scale constant is
//   1/sqrt(HEAD_SIZE) times 1.4426950408889634 (log2(e)), presumably so that
//   __builtin_amdgcn_exp2f can stand in for exp(). Dynamic shared memory holds
//   the logits and the per-warp/per-head reduction scratch; a static
//   smem_s_aux[64] holds the attention-sink values when s_aux_ptr is non-null.
//   With more than one partition the partial output is written at
//   out_tmp + out_tmp_offset and the per-partition exp-sum / max-logit are
//   written through out_tmp and out_tmp + max_tmp_offset (offsets are applied
//   to the scalar_t pointer before the cast); with a single partition the
//   result goes straight to `out`.
// * paged_attention_combine: blockIdx.x is the head, blockIdx.y the sequence;
//   returns when the sequence has a single partition. The visible part reads
//   the per-partition max logits and exp sums from out_tmp, reduces them across
//   lanes with __shfl_xor and forms inv_global_exp_sum = 1 / (sum + 1e-6f). The
//   rest of its body is missing from this listing.
// * get_reusekv (signature missing from this listing): the visible tail maps
//   the qhead : kv_head ratio onto 64, 48, 32, 24, 16, 8 or 4.
// * paged_attention_938: forward declaration only; defined elsewhere.
// * paged_attention (extern "C"): host entry point. Derives sizes from `query`
//   (size(0) = num_seqs, size(1) = mtp, size(2) * mtp = num_heads,
//   size(3) = head size) and `key_cache`, chooses PARTITION_SIZE (512 or 256,
//   overridable through PA_PARTITION_SIZE), lazily allocates a function-static
//   scratch buffer, forwards FP8 caches on gfx938 / gfx92a to
//   paged_attention_938, validates the optional s_aux tensor, then launches
//   paged_attention_kernel and, when max_num_partitions > 1,
//   paged_attention_combine.
//
// NOTE(review) -- items to confirm against the intact file
// --------------------------------------------------------
// * hipMalloc / hipMemset return values are ignored, and the scratch buffer is
//   allocated before the OptionalCUDAGuard is constructed, so it lives on
//   whichever device is current at the first call and is never freed.
// * The allocation is 110000000 bytes although the adjacent comment says 100m.
//   No check is visible that the exp-sum / max-logit region
//   (num_seqs * num_heads * max_num_partitions floats) or the partial-output
//   region fits the buffer on every path; only the first max_seq_len branch
//   clamps the sequence length.
// * max_seq_len heuristics divide by num_seqs, headsize and num_heads, and
//   real_reuse_times divides by num_kv_heads, without a zero check.
// * The `scale` argument is not referenced in the visible host code, and
//   q_scale is only forwarded to paged_attention_938.
// * The AT_ASSERTM text says "smaller than 48" while the condition allows up to
//   num_kv_heads * 64.
// * Inside the innermost dispatch lambda `constexpr bool is_e4m3=false;`
//   shadows the value chosen by Cache_Type_SWITCH, and BLOCK_SIZE is fixed at
//   64 while is_block64 is unused (a commented-out line shows the alternative).
// * No error check follows either kernel launch.
// * `has_abili` presumably means "has alibi" and the local `casual` presumably
//   means "causal"; both are left unchanged here.
// =============================================================================
#include #include #include #include #define WARP_SIZE 64 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) // Input validation macros (consistent with flash_api.cpp and flash_api_sparse.cpp) #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") static constexpr int LDS_size = 65536; static constexpr int max_tmp_offset=4000000; static constexpr int signal_tmp_offset=8000000; static constexpr int streamk_max_block=160*8; static constexpr int out_tmp_offset=signal_tmp_offset+streamk_max_block*2; // static constexpr int PARTITION_SIZE=512; #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) template static __device__ inline void from_float(scalar_t &out ,float f){ if constexpr(std::is_same::value||std::is_same::value){ out=f; } else{ uint32_t u = *(uint32_t*)(&f); // u += 0x7fff + ((u >> 16) & 1); u += 0x8000; out = u>>16; } } template static __device__ inline float to_float(scalar_t in){ if constexpr(std::is_same::value||std::is_same::value){ return in; } else{ union{ uint32_t int32; float fp32; } u = {uint32_t(in) << 16}; return u.fp32; } } inline __device__ float uint82float(const uint8_t& input) { #if (defined(__gfx938__) ||defined(__gfx92a__)) return __builtin_hcu_cvt_f32_fp8(input,false,0,0); #else const uint32_t w = (uint32_t)input << 24; const uint32_t sign = w & UINT32_C(0x80000000); const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); uint32_t renorm_shift = __clz(nonsign); renorm_shift = renorm_shift > 4 ? 
renorm_shift - 4 : 0; uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)); return c10::detail::fp32_from_bits(result); #endif } template __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) { union uf16{ uint16_t as_bits; _Float16 as_value; } ; union uf32 { uint32_t as_bits; float as_value; }; if constexpr(!is_e4m3){ uf16 u16; u16.as_bits = (uint16_t)input << 8; if constexpr(std::is_same::value){ return u16.as_value; } else{ uf32 u32; u32.as_value = (float)u16.as_value; return u32.as_bits>>16; } } else{ uf32 u32; u32.as_value = uint82float(input); if constexpr(std::is_same::value){ return (_Float16)(u32.as_value); } else{ return (uint16_t)(u32.as_bits >> 16); } } } #define BOOL_SWITCH(COND, CONST_NAME, ...) \ [&] { \ if (COND) { \ constexpr static bool CONST_NAME = true; \ return __VA_ARGS__(); \ } else { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ } \ }() #define Input_Type_SWITCH(SRC_DTYPE, ...) \ [&] { \ if (SRC_DTYPE == at::ScalarType::Half) { \ using scalar_t=_Float16; \ return __VA_ARGS__(); \ }else { \ using scalar_t=uint16_t; \ return __VA_ARGS__(); \ } \ }() #define Cache_Type_SWITCH(scalar_t,dtype, ...) \ [&] { \ if(dtype==torch::kFloat8_e5m2){ \ using cache_t=uint8_t; \ constexpr bool is_e4m3=false; \ return __VA_ARGS__(); \ }else if(dtype==torch::kFloat8_e4m3fn){ \ using cache_t=uint8_t; \ constexpr bool is_e4m3=true; \ return __VA_ARGS__(); \ }else { \ using cache_t=scalar_t; \ constexpr bool is_e4m3=false; \ return __VA_ARGS__(); \ } \ }() #define REUSEKV_SWITCH(reusekv,...) 
\ [&] { \ if (reusekv==64){ \ constexpr static int REUSE_KV_TIMES = 64; \ return __VA_ARGS__(); \ }else if (reusekv==48){ \ constexpr static int REUSE_KV_TIMES = 48; \ return __VA_ARGS__(); \ }else if (reusekv==32){ \ constexpr static int REUSE_KV_TIMES = 32; \ return __VA_ARGS__(); \ }else if (reusekv==24){ \ constexpr static int REUSE_KV_TIMES = 24; \ return __VA_ARGS__(); \ }else if (reusekv==16){ \ constexpr static int REUSE_KV_TIMES = 16; \ return __VA_ARGS__(); \ }else if (reusekv==8){ \ constexpr static int REUSE_KV_TIMES = 8; \ return __VA_ARGS__(); \ }else { \ constexpr static int REUSE_KV_TIMES = 4; \ return __VA_ARGS__(); \ } \ }() #define HEADSIZE_SWITCH(headsize,...) \ [&] { \ if (headsize==64){ \ constexpr static int HEAD_SIZE = 64; \ return __VA_ARGS__(); \ }else if(headsize==128){ \ constexpr static int HEAD_SIZE = 128; \ return __VA_ARGS__(); \ }else if(headsize==192){ \ constexpr static int HEAD_SIZE = 192; \ return __VA_ARGS__(); \ }else { \ constexpr static int HEAD_SIZE = 256; \ return __VA_ARGS__(); \ } \ }() static std::string get_device_name() { hipDeviceProp_t props{}; int device; auto status = hipGetDevice(&device); if(status != hipSuccess) { return std::string(); } status = hipGetDeviceProperties(&props, device); if(status != hipSuccess) { return std::string(); } const std::string raw_name(props.gcnArchName); return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. 
} static const std::string device_name=get_device_name(); static inline int get_env_(const char *env_var) { if (char *value = std::getenv(env_var)) { return atoi(value); } return 0; } static const int PA_USE_STREAMK = get_env_("PA_USE_STREAMK"); static const int PA_MAX_BLOCKS = get_env_("PA_MAX_BLOCKS"); static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM"); static const int PA_PARTITION_SIZE = get_env_("PA_PARTITION_SIZE"); static const int PA_GFX938 = get_env_("PA_GFX938"); using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t; using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16; using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short; using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float; template struct half4vec{ half4_t data[vec]; }; using half4x2 = half4vec<2>; template struct uint8x4vec{ uint8x4_t data[vec]; }; using uint8x4x2 = uint8x4vec<2>; using uint8x4x4 = uint8x4vec<4>; template inline __device__ float block_sum(float* red_smem, float sum) { int warp = __builtin_amdgcn_readfirstlane(threadIdx.x / WARP_SIZE); int lane = threadIdx.x % WARP_SIZE; #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { sum += __shfl_xor(sum, mask); } if (lane == 0) { red_smem[warp] = sum; } __syncthreads(); if (lane < NUM_WARPS) { sum = red_smem[lane]; } #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { sum += __shfl_xor(sum, mask); } return __shfl(sum, 0); } template inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c) { if constexpr (is_half){ reg_c=__builtin_hcu_mmac_f32_16x16x16_f16(reg_a,reg_b,reg_c); }else{ reg_c=__builtin_hcu_mmac_f32_16x16x16_bf16(*(v4bh*)®_a,*(v4bh*)®_b,reg_c); } } template // Zero means no partitioning. 
__launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads,head_size] scalar_t* __restrict__ out_tmp, // [num_seqs, num_heads, max_num_partitions,head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] const int num_heads, const int num_kv_heads, // [num_heads] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride,const int kv_block_stride, const float* k_scale_ptr, const float* v_scale_ptr,int max_num_partitions,int PARTITION_SIZE, const scalar_t* __restrict__ s_aux_ptr,int mtp,bool has_abili) { // ★ Attention Sinks: [num_heads] scalar_t ★ const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; constexpr int kv_head_stride=BLOCK_SIZE*HEAD_SIZE; const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if(num_partitions<=partition_idx)return ; constexpr bool is_half = std::is_same::value; constexpr bool is_fp8 = std::is_same::value; constexpr float scale = (HEAD_SIZE==64?0.125f:(HEAD_SIZE==128? 
0.0883883476f:(HEAD_SIZE==192?0.0721687836f:0.0625f)))*1.4426950408889634; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE); const int lane = thread_idx % WARP_SIZE; const int rowid = lane%16; const int rows = lane/16; float k_scale=scale; float v_scale=1.0; if(k_scale_ptr!=nullptr){ k_scale*=(*k_scale_ptr); v_scale=*v_scale_ptr; } const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = blockIdx.x; const int head_idx=num_queries_per_kv/mtp * kv_head_idx; constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1; constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1; extern __shared__ char shared_mem[]; scalar_t* logits = reinterpret_cast(shared_mem); float* s_max = reinterpret_cast(shared_mem + sizeof(scalar_t)*num_queries_per_kv*PARTITION_SIZE); float* s_logit = s_max + num_queries_per_kv * NUM_WARPS; float* max_out = s_logit+NUM_WARPS; float* expsum_out = max_out+num_queries_per_kv; // ★ Attention Sinks: load s_aux to shared memory ★ __shared__ scalar_t smem_s_aux[64]; if (s_aux_ptr != nullptr) { if (thread_idx < num_heads) { smem_s_aux[thread_idx] = s_aux_ptr[thread_idx]; } __syncthreads(); } const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; float alibi_slope[reuse_group]={0.f}; if (has_abili){ for(int i=0;i(shared_mem); { int head_offset = HEAD_SIZE*num_queries_per_kv/mtp; for(int i=thread_idx*8;i(s_q+i)=*reinterpret_cast(q_ptr+qoffset); } } __syncthreads(); for(int m=0;m(s_q+head_idx_*HEAD_SIZE+(i*4+rows)*8); else q_vec[m][i]=q_zero; } } __syncthreads(); const int start_block_idx = partition_idx * PARTITION_SIZE / BLOCK_SIZE; const int end_block_idx =MIN(start_block_idx + PARTITION_SIZE / BLOCK_SIZE, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; const int start_token_idx = start_block_idx * BLOCK_SIZE; const int 
end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; //comput q*k { const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) { const int64_t physical_block_number = static_cast(block_table[block_idx]); #pragma unroll for(int b=0;b(k_ptr+i*32+rowid*HEAD_SIZE+rows*8); scalar_t *p1=(scalar_t*)(k_vec+i); uint8_t *p2=(uint8_t*)&k_vec_u8; for(int ii=0;ii<8;ii++){ p1[ii]=uint82half(p2[ii]); } } else{ k_vec[i]=*reinterpret_cast(k_ptr+i*32+rowid*HEAD_SIZE+rows*8); } } __builtin_amdgcn_sched_barrier(0); #pragma unroll for(int i=0;i(k_vec[i].data[0],q_vec[m][i].data[0],qk_vec[m]); builtin_amdgcn_mmac(k_vec[i].data[1],q_vec[m][i].data[1],qk_vec[m]); } } #pragma unroll for(int i=0;i= seq_len) { int seq_len_pad=DIVIDE_ROUND_UP(seq_len,8)*8; if(token_idx1){ int casual = mtp - reuse_kv_idx * mtp / num_queries_per_kv ; if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY; } from_float(temp,qk_vec[m][ii]); logits[PARTITION_SIZE*reuse_kv_idx+token_idx- start_token_idx]=temp; qk_max[i] = fmaxf(qk_max[i], to_float(temp)); // if(partition_idx==0)printf("tid=%d,tokenid=%d,reuse_kv_idx=%d,m=%d,ii=%d,qk=%f\n",thread_idx,token_idx,reuse_kv_idx,m,i,qk_vec[m][ii]); } } } } } } // compute max #pragma unroll for (int mask = 8; mask >= 1; mask /= 2) { #pragma unroll for(int r=0;r(logits+lineid/NUM_WARPS*NUM_WARPS*2*PARTITION_SIZE+thread_idx*8); for(int ii=0;ii<8;ii++){ logit32[ii]=__builtin_amdgcn_exp2f(to_float(logit16[ii])-qk_max_tmp); exp_sum+=logit32[ii]; } // printf("tid=%d,logit32=%.4f,%.4f,%.4f,%.4f, %.4f,%.4f,%.4f,%.4f\n",thread_idx,logit32[0],logit32[1],logit32[2],logit32[3],logit32[4],logit32[5],logit32[6],logit32[7]); } for (int mask = 16; mask >= 1; mask /= 2) { exp_sum += __shfl_xor(exp_sum, mask); } exp_sum += sink_contrib; // printf("tid=%d,exp_sum=%f\n",thread_idx,exp_sum); const float inv_sum = 
__fdividef(1.f, exp_sum + 1e-6f); if(half_lane(logits+lineid/NUM_WARPS*NUM_WARPS*2*PARTITION_SIZE+thread_idx*8)=logit16; if(num_partitions>1&&half_lane==0){ max_out[real_line] = qk_max_tmp; expsum_out[real_line] = exp_sum; } } } } } else if(PARTITION_SIZE==512){ for(int lineid = warp_idx;lineid(logits+lineid/NUM_WARPS*NUM_WARPS*PARTITION_SIZE+thread_idx*8); for(int ii=0;ii<8;ii++){ logit32[ii]=__builtin_amdgcn_exp2f(to_float(logit16[ii])-qk_max_tmp); exp_sum+=logit32[ii]; } // printf("tid=%d,logit32=%.4f,%.4f,%.4f,%.4f, %.4f,%.4f,%.4f,%.4f\n",thread_idx,logit32[0],logit32[1],logit32[2],logit32[3],logit32[4],logit32[5],logit32[6],logit32[7]); } for (int mask = 32; mask >= 1; mask /= 2) { exp_sum += __shfl_xor(exp_sum, mask); } exp_sum += sink_contrib; // printf("tid=%d,exp_sum=%f\n",thread_idx,exp_sum); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); if(lane(logits+lineid/NUM_WARPS*NUM_WARPS*PARTITION_SIZE+thread_idx*8)=logit16; if(num_partitions>1&&lane==0){ max_out[lineid] = qk_max_tmp; expsum_out[lineid] = exp_sum; } } } } } __syncthreads(); constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, 16*NUM_WARPS);//2 constexpr int GROUPS=reuse_group*4; // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. 
float4_t accs[Mloop][NUM_ROWS_PER_THREAD]; for(int m=0;m; using uint8x4_vec = uint8x4vec; for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx ++) { const int64_t physical_block_number = static_cast(block_table[block_idx]); const int token_idx = block_idx * BLOCK_SIZE +rows*(BLOCK_SIZE/4); half4_vec logits_vec[Mloop]; for(int m=0;m(logits + real_row * PARTITION_SIZE+token_idx - start_token_idx + k*4); } } } const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; if(partition_idx(v_ptr + offset); scalar_t *p1=(scalar_t*)&v_vec; uint8_t *p2=(uint8_t*)&vecu8; for(int ii=0;ii(p2[ii]); } }else{ v_vec=*reinterpret_cast(v_ptr + offset); } for(int ii=0;ii(v_vec.data[ii],logits_vec[m].data[ii],accs[m][i]); } } } } else{ #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*vecsize*4+rowid*BLOCK_SIZE; half4_vec v_vec; if constexpr(is_fp8){ uint8x4_vec vecu8 = *reinterpret_cast(v_ptr + offset); scalar_t *p1=(scalar_t*)&v_vec; uint8_t *p2=(uint8_t*)&vecu8; for(int ii=0;ii(p2[ii]); } }else{ v_vec=*reinterpret_cast(v_ptr + offset); } // The if-check here costs some performance, so it is only applied for the last partition scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < 4*vecsize; j++) { v_vec_ptr[j] = token_idx + j < seq_len ? 
v_vec_ptr[j] : 0; } for(int ii=0;ii(v_vec.data[ii],logits_vec[m].data[ii],accs[m][i]); } } } } } { scalar_t* out_ptr_base; int out_offset; if(num_partitions>1){ out_offset=max_num_partitions*HEAD_SIZE; out_ptr_base=out_tmp+out_tmp_offset + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE; } else{ out_offset=HEAD_SIZE; out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE; } int head_offset = num_queries_per_kv/mtp; for(int g=0;g1&&thread_idx < num_queries_per_kv){ int out_head = thread_idx/head_offset*num_kv_heads*head_offset + thread_idx%head_offset; int offset = seq_idx * num_heads * max_num_partitions + (head_idx+out_head) * max_num_partitions + partition_idx; float * exp_sums=reinterpret_cast(out_tmp); float * max_logits=reinterpret_cast(out_tmp+max_tmp_offset); *(exp_sums+offset)=expsum_out[thread_idx]; *(max_logits+offset)=max_out[thread_idx]; } } } template __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* out_tmp, // [num_seqs, num_heads, const int* __restrict__ seq_lens, // [num_seqs] const int max_num_partitions, int num_heads, int PARTITION_SIZE) { extern __shared__ char shared_mem[]; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]); const int lane = threadIdx.x; const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if(num_partitions==1)return; float* shared_exp_sums=reinterpret_cast(shared_mem); float* shared_max_logits=shared_exp_sums+num_partitions; float max_logit = -FLT_MAX; float global_exp_sum = 0.0f; int offset = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; const float * exp_sums=reinterpret_cast(out_tmp); const float * max_logits=reinterpret_cast(out_tmp+max_tmp_offset); const float* max_logits_ptr = max_logits + offset; const float* exp_sums_ptr = exp_sums + offset; 
const scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; const scalar_t* tmp_out_ptr = out_tmp + out_tmp_offset + offset* HEAD_SIZE; for(int i=lane;i= 1; mask /= 2) { max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); } for(int i=lane;i= 1; mask /= 2) { global_exp_sum += __shfl_xor(global_exp_sum, mask); } const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); constexpr int vec_size_o=HEAD_SIZE/64; constexpr int vec_size = vec_size_o==3?4:vec_size_o; using half_vec= __attribute__( (__vector_size__(vec_size * sizeof(scalar_t)) )) scalar_t; using float_vec= __attribute__( (__vector_size__(vec_size * sizeof(float)) )) float; float_vec acc = {0.0f}; half_vec acc_half; if(lanekv_head*48) return 64; if(qhead>kv_head*32) return 48; if(qhead>kv_head*24) return 32; if(qhead>kv_head*16) return 24; if(qhead>kv_head*8) return 16; if(qhead>kv_head*4)return 8; return 4; } void paged_attention_938( torch::Tensor& out, // [num_seqs,seqlen, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] torch::Tensor& value_cache,// [num_blocks, num_heads, head_size, block_size] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] const c10::optional& alibi_slopes, const c10::optional& q_scale, const c10::optional& k_scale, const c10::optional& v_scale, int max_seq_len, const c10::optional &s_aux_, float *tmp_out_ptr, int PARTITION_SIZE); // ★ Attention Sinks ★ extern "C" void paged_attention( torch::Tensor& out, // [num_seqs,seqlen, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] torch::Tensor& value_cache,// [num_blocks, num_heads, head_size, block_size] double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] 
const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, //auto,int8,fp8/fp8_e4m3 const c10::optional& q_scale, const c10::optional& k_scale, const c10::optional& v_scale, int max_seq_len, const c10::optional &s_aux_) // ★ Attention Sinks ★ { int num_seqs = query.size(0); int headsize=query.size(3); int block_size=key_cache.size(2); int mtp = query.size(1); int num_blocks = key_cache.size(0); int max_num_blocks_per_seq = block_tables.size(1); int num_heads = query.size(2)*mtp; int num_kv_heads = key_cache.size(1); int PARTITION_SIZE=512; int reusekv=get_reusekv(num_heads,num_kv_heads); if(reusekv>15)PARTITION_SIZE=256; //if seq<10,the seq is invalid if (max_seq_len<=10||(max_seq_len>=8192&&max_seq_len==max_num_blocks_per_seq*block_size)){ int meanseq = num_blocks*block_size/num_seqs+4096; int maxseq = 100000000/num_seqs/headsize/num_heads*64; if(reusekv<16) maxseq*=2; max_seq_len=MIN(max_num_blocks_per_seq*block_size,MIN(meanseq,maxseq)); } else{ int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE); if(max_num_partitions*num_seqs*num_kv_heads<=160)PARTITION_SIZE=256; if(num_seqs*num_kv_heads<=32&&max_seq_len<=32768)PARTITION_SIZE=256; } int real_reuse_times = num_heads/num_kv_heads; if(PA_PARTITION_SIZE!=0)PARTITION_SIZE=PA_PARTITION_SIZE; int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE); static float* tmp_out_ptr = nullptr; constexpr int temp_out_size = 110000000; if(tmp_out_ptr == nullptr){ hipMalloc(&tmp_out_ptr, temp_out_size); // 100m hipMemset(tmp_out_ptr,0,temp_out_size); } if((device_name=="gfx938"|| device_name == "gfx92a")&&(key_cache.dtype()==torch::kFloat8_e5m2||key_cache.dtype()==torch::kFloat8_e4m3fn)){ paged_attention_938(out,query,key_cache,value_cache,block_tables,seq_lens,alibi_slopes,q_scale,k_scale,v_scale,max_seq_len,s_aux_,tmp_out_ptr,PARTITION_SIZE); return; } int head_size = query.size(3); int q_stride = query.stride(0); int kv_block_stride = key_cache.stride(0); int kv_head_stride = 
key_cache.stride(1); const float* alibi_slopes_ptr =alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()):nullptr; int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); auto* out_ptr = out.data_ptr(); const float* k_scale_ptr = k_scale? reinterpret_cast(k_scale.value().data_ptr()):nullptr; const float* v_scale_ptr = v_scale? reinterpret_cast(v_scale.value().data_ptr()):nullptr; // Attention Sinks: validate and set s_aux_ptr const void* s_aux_ptr = nullptr; if (s_aux_.has_value()) { auto s_aux = s_aux_.value(); // ★ s_aux must match Q/K/V dtype (Element type) for mixed precision TORCH_CHECK(s_aux.dtype() == query.dtype(), "s_aux must have the same dtype as query. Got s_aux dtype: ", s_aux.dtype(), ", query dtype: ", query.dtype()); TORCH_CHECK(s_aux.dtype() == torch::kFloat16 || s_aux.dtype() == torch::kBFloat16, "s_aux must have dtype float16 or bfloat16 (to match query). Got: ", s_aux.dtype()); TORCH_CHECK(num_heads <= 64, "Attention Sinks only supports up to 64 heads (shared memory limit), got ", num_heads); CHECK_DEVICE(s_aux); CHECK_SHAPE(s_aux, num_heads); CHECK_CONTIGUOUS(s_aux); s_aux_ptr = s_aux.data_ptr(); } auto* query_ptr = query.data_ptr(); auto* key_cache_ptr = key_cache.data_ptr(); auto* value_cache_ptr = value_cache.data_ptr(); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 reduce_grid(num_heads, num_seqs); dim3 grid; grid.x = num_kv_heads; grid.y = num_seqs; AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256"); AT_ASSERTM(num_heads<=num_kv_heads*64, "Page Attention qheads*mtp/kvheads must be smaller than 48"); HEADSIZE_SWITCH(headsize,[&]{ Input_Type_SWITCH(query.dtype(),[&]{ Cache_Type_SWITCH(scalar_t,key_cache.dtype(),[&] { REUSEKV_SWITCH(reusekv,[&] { BOOL_SWITCH(block_size==64,is_block64,[&]{ // constexpr int BLOCK_SIZE = (is_block64?64:128); constexpr int 
BLOCK_SIZE=64; // constexpr int HEAD_SIZE=128; // using scalar_t=uint16_t; // using cache_t = scalar_t; constexpr bool is_e4m3=false; // constexpr static int REUSE_KV_TIMES = 4; constexpr static int NUM_THREADS = 256; constexpr static int NUM_WARPS = NUM_THREADS / WARP_SIZE; int other_use = (real_reuse_times*NUM_WARPS+NUM_WARPS+ real_reuse_times*2)*sizeof(float); int shared_mem_size=PARTITION_SIZE*2*real_reuse_times+other_use; grid.z = max_num_partitions; dim3 block(NUM_THREADS); if(PA_PRINT_PARAM&&static_cast(query.get_device())==0)printf("is_fp8=%d,shared_mem_size=%d,HEAD_SIZE=%d,BLOCK_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d\n", (int)(sizeof(cache_t)==1),shared_mem_size,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs,PARTITION_SIZE,max_num_partitions); paged_attention_kernel<<>>( (scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr, (scalar_t*)query_ptr,(cache_t*) key_cache_ptr, (cache_t*)value_cache_ptr, num_heads, num_kv_heads, block_tables_ptr, seq_lens_ptr,max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, k_scale_ptr, v_scale_ptr,max_num_partitions,PARTITION_SIZE,(const scalar_t*)s_aux_ptr,mtp,alibi_slopes_ptr!=nullptr); if(max_num_partitions>1){ paged_attention_combine<<>>( (scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr,seq_lens_ptr,max_num_partitions,num_heads,PARTITION_SIZE); } }); }); }); }); }); }