/* NOTE(review): this text looks like a damaged extraction of the original
 * source file. Line breaks were collapsed (several #includes, // comments and
 * definitions share one physical line), the contents of angle brackets were
 * stripped (#include targets, template <...> parameter lists,
 * reinterpret_cast<...> targets, and whole spans between a `<` comparison
 * and the next `>`, e.g. `for(int i=0;i(q_ptr+...`), and `&reg_a` shows up
 * as `®_a`. As shown it cannot compile; the code is kept byte-for-byte here
 * and should be restored from version control before any logic change.
 *
 * Contents of this stretch: includes; WARP_SIZE (32 on CUDA, warpSize on
 * ROCm, where __hip_bfloat16 is typedef'd to __nv_bfloat16);
 * MAX / MIN / DIVIDE_ROUND_UP helpers; get_device_name(), which returns
 * props.gcnArchName truncated at the first ':' (or an empty string when
 * hipGetDevice / hipGetDeviceProperties fails); get_env_(), which returns
 * atoi() of an environment variable or 0 when it is unset; the
 * PA_REUSE_KV_TIMES, PA_BLOCK_SIZE and PA_PRINT_PARAM overrides, read once
 * during static initialization; and the start of vllm::block_sum.
 *
 * block_sum: XOR-shuffle sum within each warp, lane 0 of every warp stores
 * its partial in red_smem[warp], __syncthreads(), the first NUM_WARPS lanes
 * reload and XOR-reduce those partials, and lane 0's value is broadcast to
 * the warp. red_smem must provide at least NUM_WARPS floats. */
#include #include #include #include #include "attention_dtypes.h" #include "attention_utils.cuh" #ifdef USE_ROCM #include #include "../quantization/fp8/amd/quant_utils.cuh" typedef __hip_bfloat16 __nv_bfloat16; #else #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM #define WARP_SIZE 32 #else #define WARP_SIZE warpSize #endif #include "static_switch_tc.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) inline std::string get_device_name() { hipDeviceProp_t props{}; int device; auto status = hipGetDevice(&device); if(status != hipSuccess) { return std::string(); } status = hipGetDeviceProperties(&props, device); if(status != hipSuccess) { return std::string(); } const std::string raw_name(props.gcnArchName); return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. } static inline int get_env_(const char *env_var) { if (char *value = std::getenv(env_var)) { return atoi(value); } return 0; } static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES"); static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE"); static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM"); namespace vllm { // Utility function for attention softmax. template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; int lane = threadIdx.x % WARP_SIZE; // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Warp leaders store the data to shared memory. if (lane == 0) { red_smem[warp] = sum; } // Make sure the data is in shared memory. __syncthreads(); // The warps compute the final sums. if (lane < NUM_WARPS) { sum = red_smem[lane]; } // Parallel reduction inside the warp.
/* End of block_sum, then the MMAC helpers:
 *  - half4_t / v4bh / float4_t: 4-lane vectors of _Float16, short and float.
 *  - half4x2: two half4_t (eight 16-bit values, 16 bytes).
 *  - float4_2_half4: converts 4 floats into dst, as _Float16 when is_half,
 *    otherwise as __nv_bfloat16 written through a reinterpret_cast of dst.
 *  - v_mmac_f32_16x16x16_f16: inline-asm v_mmac_f32_16x16x16_{f16,bf16},
 *    accumulating reg_a x reg_b into reg_c in place.
 *  - builtin_amdgcn_mmac: uses the inline-asm wrapper when use_vmac,
 *    otherwise __builtin_amdgcn_mmac_f32_16x16x16f16 / ...bf16; the bf16
 *    path reinterprets both operands as v4bh.
 *    NOTE(review): `*(v4bh*)®_a` / `®_b` is presumably `*(v4bh*)&reg_a` /
 *    `&reg_b` mangled into the ® character by the extraction.
 * The physical line ends with the (stripped) template header of
 * paged_attention_kernel_TC. */
#pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Broadcast to other threads. return VLLM_SHFL_SYNC(sum, 0); } using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16; using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short; using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; struct half4x2{ half4_t data[2]; }; template inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src) { if constexpr(is_half){ #pragma unroll for(int i=0;i<4;i++){ dst[i]=src[i]; } } else{ __nv_bfloat16* out = reinterpret_cast<__nv_bfloat16 *>(&dst); #pragma unroll for(int i=0;i<4;i++){ out[i]=__float2bfloat16(src[i]); } } } template inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c) { if constexpr (is_half){ asm volatile("v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); } else{ asm volatile("v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" : "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); } } template inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c) { if constexpr (use_vmac){v_mmac_f32_16x16x16_f16(reg_a,reg_b,reg_c);} else{ if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);} else{ reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)®_a,*(v4bh*)®_b,reg_c); } } } // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). template // Zero means no partitioning.
/* paged_attention_kernel_TC: device body shared by the v1 and v2 kernels.
 * NOTE(review): the "Grid: (num_heads, num_seqs, max_num_partitions)"
 * comment in front of this template does not match the indexing below:
 * seq_idx = blockIdx.z, partition_idx = blockIdx.y (max_num_partitions =
 * gridDim.y), and blockIdx.x selects a group of up to REUSE_KV_TIMES query
 * heads. The launchers build grid = (head groups, partitions, num_seqs).
 * A block whose partition starts at or beyond seq_len returns immediately.
 * The return is uniform across the block, since partition_idx is
 * per-block. Without partitioning (PARTITION_SIZE == 0), one block covers
 * the whole sequence and partition_size is num_seq_blocks * BLOCK_SIZE.
 * HEAD_SIZE <= 4 * NUM_THREADS is enforced at compile time. */
__device__ void paged_attention_kernel_TC( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] const int num_heads, const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int seq_idx = blockIdx.z; const int partition_idx = blockIdx.y; const int max_num_partitions = gridDim.y; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]); if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { // No work to do. Terminate the thread block. return; } constexpr bool is_half = std::is_same::value; static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE; // [start_block_idx, end_block_idx) is the range of blocks to process.
/* Token / lane bookkeeping. Each block processes KV blocks
 * [start_block_idx, end_block_idx) and tokens
 * [start_token_idx, end_token_idx) of its partition. The trailing
 * //0,64,128... style comments appear to be example values from the
 * author's configuration. Lanes are split into rowid = lane % 16 and
 * rows = lane / 16 for the 16x16x16 MMAC layout.
 * One thread block serves REUSE_KV_TIMES consecutive query heads that share
 * a KV head: num_blocks_per_kv = ceil(num_queries_per_kv / REUSE_KV_TIMES).
 * blockIdx_shift swaps the two halves of blockIdx.x when gridDim.x is even
 * and ((blockIdx.z*gridDim.y*gridDim.x + blockIdx.y*gridDim.x) / 128) is
 * even; presumably a scheduling / load-balancing tweak -- TODO confirm. */
const int start_block_idx = partition_idx * num_blocks_per_partition;//0,64,128… const int end_block_idx =MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);//64,128,192… const int num_blocks = end_block_idx - start_block_idx;//64 or 1-63 // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE;//0,1024,2048… const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);//1024,2048,3072… const int num_tokens = end_token_idx - start_token_idx;//1024 or 1-1023 // divides NUM_THREADS constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;//4 constexpr int x = 16 / sizeof(cache_t);//8 const int thread_idx = threadIdx.x; const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE); const int lane = thread_idx % WARP_SIZE; const int rowid = lane%16; const int rows = lane/16; const int num_queries_per_kv = num_heads / num_kv_heads; const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES); const int odd_tg_round = (((blockIdx.z * gridDim.y * gridDim.x) + blockIdx.y * gridDim.x) / 128) % 2; const int mid_x = gridDim.x / 2; const int blockIdx_shift = (odd_tg_round | (gridDim.x & 1)) ? blockIdx.x : (blockIdx.x < mid_x ?
/* head_idx is the first query head of this block's group. q_boundary is
 * the number of its REUSE_KV_TIMES slots that hold real heads; it is smaller
 * for the last group of a KV head when num_queries_per_kv is not a multiple
 * of REUSE_KV_TIMES.
 * alibi_slope (reuse_group = ceil(REUSE_KV_TIMES / 4) entries,
 * zero-initialized) is filled from alibi_slopes when that pointer is
 * non-null. That fill loop and the query load were largely lost in
 * extraction.
 * NOTE(review): despite the inline "We use FP32 for the softmax logits"
 * comment, `logits` is declared scalar_t* and is written via from_float, so
 * logits are kept at scalar_t precision. The launchers also size this
 * buffer at 2 bytes per logit.
 * red_smem holds 2 * NUM_WARPS floats: [0, NUM_WARPS) for the max
 * reduction and [NUM_WARPS, 2 * NUM_WARPS) for block_sum. */
(blockIdx.x + mid_x) : (blockIdx.x - mid_x)); const int head_idx = (blockIdx_shift / num_blocks_per_kv) * num_queries_per_kv + (blockIdx_shift % num_blocks_per_kv) * REUSE_KV_TIMES; //const int head_idx=(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES; int q_boundary=REUSE_KV_TIMES; if(num_heads < REUSE_KV_TIMES*gridDim.x && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv) q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES; const int kv_head_idx = head_idx / num_queries_per_kv; constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1; float alibi_slope[reuse_group]={0.f}; if(alibi_slopes != nullptr){ for(int i=0;i(q_ptr+i*HEAD_SIZE+thread_idx*8); } } __syncthreads(); // Memory planning. extern __shared__ char shared_mem[]; // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. scalar_t* logits = reinterpret_cast(shared_mem); // Workspace for reduction. __shared__ float red_smem[2 * NUM_WARPS]; // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query.
/* QK^T stage.
 * Warp warp_idx handles KV blocks start_block_idx + warp_idx, stepping by
 * NUM_WARPS. Each lane loads 16 bytes of K (half4x2) from k_ptr and from
 * offsets of (i + 1) * 512 elements. It alternates between k_vec[0] and
 * k_vec[1] and accumulates into qk_vec with two MMACs per step.
 * NOTE(review): the tail step calls v_mmac_f32_16x16x16_f16 directly rather
 * than builtin_amdgcn_mmac, so it ignores use_vmac -- confirm intended.
 * Each score is zeroed for slots past q_boundary, otherwise multiplied by
 * `scale`. When the slope is non-zero it gets the ALiBi term
 * alibi_slope * (token_idx - seq_len + 1). It is masked when
 * token_idx >= seq_len or attn_masks[seq_idx * attn_masks_stride +
 * token_idx] == 0. Masked scores are stored as 0 and skipped for qk_max.
 * NOTE(review): attn_masks-masked tokens inside
 * [start_token_idx, end_token_idx) are stored as 0 rather than -inf. The
 * softmax then still gives them weight exp(0 - max) / sum. This looks like
 * a bug for tree-style masks -- verify. Tokens >= seq_len lie outside
 * num_tokens and are not affected. */
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; // blocksparse specific vars int bs_block_offset; int q_bs_block_id; const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int64_t physical_block_number = static_cast(block_table[block_idx]); const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride; float4_t qk_vec={0,0,0,0}; half4x2 k_vec[2]; k_vec[0]=*reinterpret_cast(k_ptr); #pragma unroll for(int i=0;i<3;i++){ if(rowid(k_ptr+(i+1)*512); builtin_amdgcn_mmac(k_vec[i%2].data[0],q_vec.data[0],qk_vec); builtin_amdgcn_mmac(k_vec[i%2].data[1],q_vec.data[1],qk_vec); } //tail { if(rowid(k_vec[1].data[0],q_vec.data[0],qk_vec); v_mmac_f32_16x16x16_f16(k_vec[1].data[1],q_vec.data[1],qk_vec); } #pragma unroll for(int i=0;i=q_boundary)qk_vec[i]=0; else qk_vec[i]*=scale; const int token_idx = block_idx * BLOCK_SIZE+rowid; if(alibi_slope[i] != 0){ float alibi=alibi_slope[i]* (token_idx - seq_len + 1); qk_vec[i] += alibi; } bool mask = (token_idx >= seq_len); // used for tree-style attention if (attn_masks != nullptr) { const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride; mask |= attn_masks_ptr[token_idx] == 0; } if(mask){ from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f); } else{ from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]); qk_max[i] = fmaxf(qk_max[i], qk_vec[i]); } } } } // if(blockIdx.x==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]); // Perform reduction across the threads in the same warp to get the // max qk value for each "warp" (not across the thread block yet). // The 0-th thread of each thread group already has its max qk value.
/* Softmax, once per reused head (reuse_kv_idx).
 * First a warp-level max via XOR shuffles; the shuffle width was lost in
 * extraction. The lane with rowid == 0 and rows == reuse_kv_idx % 4 writes
 * it to red_smem[warp_idx]. A block-wide max is then broadcast from lane 0.
 * Logits [0, num_tokens) of this head's slice
 * (logits + reuse_kv_idx * partition_size) are replaced by
 * __expf(logit - max). They are summed with block_sum and scaled by
 * 1 / (exp_sum + 1e-6f). */
for(int reuse_kv_idx=0; reuse_kv_idx= 1; mask /= 2) { qk_max_tmp = fmaxf(qk_max_tmp, VLLM_SHFL_XOR_SYNC(qk_max_tmp, mask)); } if (rowid==0 && reuse_kv_idx%4==rows) { red_smem[warp_idx] = qk_max_tmp; } __syncthreads(); // TODO(woosuk): Refactor this part. // Get the max qk value for the sequence. qk_max_tmp = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { qk_max_tmp = fmaxf(qk_max_tmp, VLLM_SHFL_XOR_SYNC(qk_max_tmp, mask)); } // Broadcast the max qk value to all threads. qk_max_tmp = VLLM_SHFL_SYNC(qk_max_tmp, 0); for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp); from_float(logits[(reuse_kv_idx * partition_size) + i] , val); exp_sum += val; } exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); // Compute softmax. const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum); } __syncthreads(); // If partitioning is enabled, store the max logit and exp_sum.
/* With partitioning, thread 0 records this head's max logit and exp sum at
 * [seq_idx][head_idx_][partition_idx] for the v2 reduce kernel. head_idx_
 * is declared in a span lost in extraction; presumably it is
 * head_idx + reuse_kv_idx.
 * P*V stage. NUM_ROWS_PER_THREAD = ceil(HEAD_SIZE / WARP_SIZE); the inline
 * //2 matches HEAD_SIZE 128 with a 64-lane wavefront.
 * For REUSE_KV_TIMES <= 2:
 *  - the probabilities of all reused heads are packed into one logits_vec,
 *    taken from head slice rowid / 4 and loaded only when
 *    rowid < 4 * q_boundary;
 *  - V is read as half4_t at v_ptr + i * 1024 + k * 256;
 *  - in the last sequence block, V entries with token_idx + j >= seq_len
 *    are replaced by zero_value before the MMAC.
 * Cross-warp reduction: upper warps publish their accs through shared
 * memory (out_smem aliases the same shared_mem as logits). Lower warps add
 * them, halving the number of active warps each round.
 * NOTE(review): the slot index (warp_idx - mid) * 64 + lane and the
 * out_smem[thread_idx] read hard-code a 64-lane wavefront. The inner
 * `for (int i ...)` also shadows the outer round counter `i`. */
if (USE_PARTITIONING && thread_idx == 0) { float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + head_idx_ * max_num_partitions + partition_idx; *max_logits_ptr = qk_max_tmp; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + head_idx_ * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } } constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2 if constexpr(REUSE_KV_TIMES<=2){ float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { #pragma unroll for(int k=0;k(block_table[block_idx]); const int token_idx = block_idx * BLOCK_SIZE +rows*4; half4_t logits_vec={0,0,0,0}; if(rowid<4*q_boundary){ logits_vec=*reinterpret_cast(logits + rowid/4 * partition_size+token_idx - start_token_idx); } const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride + rows*4+rowid*16; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { #pragma unroll for(int k=0;k<4;k++){ int offset=i*1024+k*256; half4_t v_vec=*reinterpret_cast(v_ptr + offset); if (block_idx == num_seq_blocks - 1) { scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < 4; j++) { v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; } } float4_t out_vec={0,0,0,0}; builtin_amdgcn_mmac(v_vec,logits_vec,out_vec); if(rows==k){ for(int resuseid=0;resuseid64){ floatV_t* out_smem = reinterpret_cast(shared_mem); #pragma unroll for (int i = NUM_WARPS; i > 1; i /= 2) { int mid = i / 2; // Upper warps write to shared memory. if (warp_idx >= mid && warp_idx < i) { out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]); } __syncthreads(); // Lower warps update the output. if (warp_idx < mid) { floatV_t tmp=out_smem[thread_idx]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { accs[reuse_kv_idx][i] += tmp[i]; } } __syncthreads(); } } // Write the final output.
/* Warp 0 writes the reduced accumulators of each reused head to
 * out[seq_idx][head_idx + reuse_kv_idx][partition_idx][row], with
 * row = lane + i * WARP_SIZE.
 * The else branch (REUSE_KV_TIMES > 2):
 *  - keeps GROUPS = reuse_group * 4 accumulator sets;
 *  - loads logits_vec from head slice rowid;
 *  - applies the same last-block V masking;
 *  - runs the same shared-memory reduction, preceded by an extra
 *    __syncthreads(). */
if (warp_idx == 0) { scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + (head_idx+reuse_kv_idx) * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane + i * WARP_SIZE; from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]); } } } } else{ constexpr int GROUPS=reuse_group*4; // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[GROUPS][NUM_ROWS_PER_THREAD]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { #pragma unroll for(int k=0;k(block_table[block_idx]); const int token_idx = block_idx * BLOCK_SIZE +rows*4; half4_t logits_vec={0,0,0,0}; if(rowid(logits + rowid * partition_size+token_idx - start_token_idx); } const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride + rows*4+rowid*16; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { #pragma unroll for(int k=0;k<4;k++){ int offset=i*1024+k*256; half4_t v_vec=*reinterpret_cast(v_ptr + offset); if (block_idx == num_seq_blocks - 1) { scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < 4; j++) { v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; } } float4_t out_vec={0,0,0,0}; builtin_amdgcn_mmac(v_vec,logits_vec,out_vec); for(int g=0;g64){ __syncthreads(); using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float; // Perform reduction across warps. for(int reuse_kv_idx=0; reuse_kv_idx(shared_mem); #pragma unroll for (int i = NUM_WARPS; i > 1; i /= 2) { int mid = i / 2; // Upper warps write to shared memory. if (warp_idx >= mid && warp_idx < i) { out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]); } __syncthreads(); // Lower warps update the output.
/* Last reduction round of the else branch, then warp 0's output loop.
 * NOTE(review): everything between `g<` of that loop and the `>` of the
 * next template header was lost in extraction (`for(int g=0;g __global__`).
 * That covers the else-branch output writes, the end of
 * paged_attention_kernel_TC and the template parameters of
 * paged_attention_v1_kernel_TC.
 * paged_attention_v1_kernel_TC is a __global__ wrapper that forwards to the
 * device body with null exp_sums / max_logits. Its body is compiled only
 * when __gfx928__ is defined, so on any other target it is an empty
 * kernel. */
if (warp_idx < mid) { floatV_t tmp=out_smem[thread_idx]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { accs[reuse_kv_idx][i] += tmp[i]; } } __syncthreads(); } } } if (warp_idx == 0) { for(int g=0;g __global__ void paged_attention_v1_kernel_TC( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] const int num_heads, const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { #ifdef __gfx928__ paged_attention_kernel_TC( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_heads,num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step, attn_masks, attn_masks_stride); #endif } // Grid: (num_heads, num_seqs, max_num_partitions).
/* paged_attention_v2_kernel_TC: __global__ wrapper, __launch_bounds__(256, 1).
 * It forwards exp_sums, max_logits and tmp_out to
 * vllm::paged_attention_kernel_TC for the partitioned (v2) path.
 * The body exists only when compiling for __gfx928__; on other targets the
 * kernel does nothing.
 * NOTE(review): the template parameter lists of this wrapper and of the
 * forwarded call were stripped by the extraction that produced this text.
 * The "Grid: (num_heads, num_seqs, max_num_partitions)" comment that
 * precedes it also does not match the kernel body, which reads
 * seq_idx = blockIdx.z and partition_idx = blockIdx.y. */
template __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] const int num_heads, // [num_heads] const int num_kv_heads, // [num_kv_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { #ifdef __gfx928__ paged_attention_kernel_TC( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step, attn_masks, attn_masks_stride); #endif } // Grid: (num_heads, num_seqs).
/* paged_attention_v2_reduce_kernel_opt_tc: __launch_bounds__(256, 1), one
 * block per (head, seq): num_heads = gridDim.x, head_idx = blockIdx.x,
 * seq_idx = blockIdx.y.
 * With a single partition it just copies tmp_out to out. Otherwise dynamic
 * shared memory holds num_partitions max logits followed by num_partitions
 * rescaled exp sums (2 * num_partitions floats). The global max logit is
 * found with a warp XOR-shuffle max, red_smem, and a cross-warp max.
 * NOTE(review): the template parameter list was stripped by extraction. */
template __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_tc( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] const int* __restrict__ seq_lens, // [num_seqs] const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; const int seq_len = seq_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } // Terminate the thread block. return; } constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warp_idx = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; // Size: 2 * num_partitions. extern __shared__ char shared_mem[]; // Workspace for reduction. __shared__ float red_smem[2 * NUM_WARPS]; // Load max logits to shared memory. float* shared_max_logits = reinterpret_cast(shared_mem); const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; shared_max_logits[i] = l; max_logit = fmaxf(max_logit, l); } __syncthreads(); // Get the global max logit. // Reduce within the warp.
/* Each partition's exp sum is rescaled by expf(max_logit_p - max_logit) and
 * stored in shared memory. The rescaled sums are added with block_sum into
 * global_exp_sum, and inv_global_exp_sum = 1 / (global_exp_sum + 1e-6f). */
#pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } if (lane == 0) { red_smem[warp_idx] = max_logit; } __syncthreads(); // Reduce across warps. max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } // Broadcast the max value to all threads. max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); global_exp_sum += rescaled_exp_sum; shared_exp_sums[i] = rescaled_exp_sum; } __syncthreads(); global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out.
/* out[i] = sum over p of tmp_out[p][i] * rescaled_exp_sum[p] *
 * inv_global_exp_sum, strided by NUM_THREADS. The other loops in this
 * kernel stride by blockDim.x, so the two presumably agree only when it is
 * launched with NUM_THREADS threads -- confirm.
 * After namespace vllm closes:
 *  - LAUNCH_PAGED_ATTENTION_V1_TC raises the kernel's max dynamic shared
 *    memory to shared_mem_size and launches it. No error check follows the
 *    launch.
 *  - get_numberthread_and_reuse_kv_v1 is a hand-tuned table that picks
 *    num_thread and reusekv from batch size, padded sequence length and
 *    the q/kv head ratio. MHA (qheads == kvheads) keeps reusekv = 1; those
 *    thresholds were tuned on LLaMA-7B. */
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE; scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; } from_float(out_ptr[i], acc); } } } // namespace vllm #define LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ((void*)vllm::paged_attention_v1_kernel_TC), \ shared_mem_size); \ vllm::paged_attention_v1_kernel_TC \ <<>>( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step, attn_masks_ptr, \ attn_masks_stride); void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){ //mha reusekv=1; if(qheads==kvheads){ //llama 7B ,other models unknown if(seq<=16||batchsize>=32)num_thread=64; else if(batchsize<=2)num_thread=256; else if(batchsize<8)num_thread=128; else num_thread=64; return; } // mqa if(qheads>kvheads*4){ if(seq<64){ if(batchsize<=64){reusekv=1;num_thread=64;} else if(batchsize<128){reusekv=2;num_thread=64;} else {reusekv=4;num_thread=64;} } else if(seq<=400){ if(batchsize<16){reusekv=1;num_thread=256;} else if(batchsize<64){reusekv=2;num_thread=256;} else if(batchsize<=128){ reusekv=4; if(qheads%7==0)num_thread=64;//qwen7b else num_thread=256;//llama70b } else {reusekv=8;num_thread=64;} } else if(seq<=1000){ if(batchsize<16){reusekv=1;num_thread=256;} else if(qheads%7==0&&batchsize<=128){//qwen7b if(batchsize<64){reusekv=4;num_thread=256;}
/* Rest of the v1 heuristic. For head ratios <= 4, reusekv is 8 / 4 / 2
 * under sequence-length caps, and num_thread is 64 for seq <= 64 (with
 * reusekv forced to 1 when batchsize <= 64), otherwise 256.
 * NOTE(review): `qheads/kvheads >4` can never hold at that point, because
 * qheads > kvheads*4 already returned above.
 * paged_attention_v1_launcher_opt_tc begins here; its template parameter
 * list was stripped by extraction.
 * NOTE(review): `num_threads` (128, or 256 when num_heads != num_kv_heads)
 * is never read afterwards. The launch uses `num_thread` from the heuristic
 * via NUM_THREADS_SWITCH. */
else{reusekv=4;num_thread=64;} } else if(batchsize<=64){reusekv=4;num_thread=256;} else {reusekv=8;num_thread=128;} } else if(seq<3900) {reusekv=8;num_thread=256;} else if(seq<7800) {reusekv=4;num_thread=256;} else {reusekv=2;num_thread=256;} return; } if(qheads/kvheads >4 && seq<3900)reusekv=8; else if(qheads/kvheads >2 && seq<7800)reusekv=4; else if(qheads/kvheads >=2 && seq<15600)reusekv=2; if(seq<=64){ num_thread=64; if(batchsize<=64)reusekv=1; } else num_thread=256; } // TODO(woosuk): Tune NUM_THREADS. template void paged_attention_v1_launcher_opt_tc( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, const c10::optional& alibi_slopes, float k_scale, float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step, const c10::optional& attn_masks, const int64_t attn_masks_stride=0) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); int max_num_blocks_per_seq = block_tables.size(1); int q_stride = query.stride(0); int kv_block_stride = key_cache.stride(0); int kv_head_stride = key_cache.stride(1); int num_threads = 128; // printf("paged_attention_v1\n"); if (num_heads != num_kv_heads) { num_threads = 256; } [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. const float* alibi_slopes_ptr = alibi_slopes ?
/* v1 launcher body: raw pointers, padded_max_seq_len (max_seq_len rounded
 * up to BLOCK_SIZE), device guard and current stream.
 * Only instantiations with BLOCK_SIZE == 16, no block sparsity, 2-byte T
 * and a kAuto KV cache launch anything; every other instantiation compiles
 * to a no-op.
 * HEAD_SIZE is fixed at 128 and is not checked against head_size here.
 * paged_attention_v1_opt_tc appears to route only query.size(2) == 128 to
 * this path.
 * PA_REUSE_KV_TIMES (applied only when num_heads > num_kv_heads) and
 * PA_BLOCK_SIZE override the heuristic; PA_PRINT_PARAM prints the chosen
 * configuration.
 * Dynamic shared memory is the larger of
 * REUSE_KV_TIMES * padded_max_seq_len * 2 and
 * (NUM_WARPS / 2) * head_size * sizeof(float) bytes. It therefore grows
 * with the longest sequence -- verify it stays under the LDS limit for long
 * inputs.
 * NOTE(review): `NUM_WARPS == 64` would need 64 * WARP_SIZE threads;
 * NUM_THREADS == 64 may have been meant (outputs_size is 0 then anyway).
 * The CALL_V1_LAUNCHER / _SPARSITY / _BLOCK_SIZE macros that follow
 * instantiate the launcher for block sizes 8, 16 and 32 and both sparsity
 * settings. paged_attention_v1_opt is only declared here, as the generic
 * fallback. */
reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); // NOTE: attn_masks is optional. const int* attn_masks_ptr = attn_masks ? attn_masks.value().data_ptr() : nullptr; int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){ constexpr int HEAD_SIZE=128; constexpr static int use_vmac = false; int reusekv, num_thread; get_numberthread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads); if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES; if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE; REUSEKV_SWITCH(reusekv,[&] { NUM_THREADS_SWITCH(num_thread , [&] { //constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2; int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); if (NUM_WARPS==64)outputs_size=0; int shared_mem_size = ::max(logits_size, outputs_size); dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs); dim3 block(NUM_THREADS); if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n", reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs); LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE); }); }); } } #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher_opt_tc( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step, \ attn_masks, attn_masks_stride); #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ switch (is_block_sparse) { \ case true: \ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ break; \ case false: \ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ break; \ } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ switch (block_size) { \ case 8: \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ break; \ case 16: \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ break; \ case 32: \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } void paged_attention_v1_opt( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] int64_t num_kv_heads, // [num_heads] double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step, const c10::optional& attn_masks, // [num_seqs, max_seq_len] const int64_t attn_masks_stride); void
/* paged_attention_v1_opt_tc: entry point. It falls back to
 * paged_attention_v1_opt unless all of these hold:
 *  - kv_cache_dtype == "auto";
 *  - the query is not float;
 *  - the layout is not block sparse (blocksparse_vert_stride <= 1);
 *  - block_size == 16;
 *  - query.size(2) == 128;
 *  - get_device_name() == "gfx928".
 * When they do, it dispatches through DISPATCH_BY_KV_CACHE_DTYPE.
 * NOTE(review): get_device_name() issues hipGetDevice +
 * hipGetDeviceProperties on every call that reaches that check. Caching the
 * result would avoid the per-call cost.
 * LAUNCH_PAGED_ATTENTION_V2_TC launches the partitioned kernel and then the
 * reduce kernel on the same stream with hipLaunchKernelGGL, again without
 * an error check. */
paged_attention_v1_opt_tc( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] int64_t num_kv_heads, // [num_heads] double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step, const c10::optional& attn_masks, // [num_seqs, max_seq_len] const int64_t attn_masks_stride) { const bool is_block_sparse = (blocksparse_vert_stride > 1); if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, blocksparse_block_size,blocksparse_head_sliding_step, attn_masks, attn_masks_stride); } else{ DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) } } #define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \ hipLaunchKernelGGL( \ (vllm::paged_attention_v2_kernel_TC< \ T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \ IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac, PARTITION_SIZE>), \ dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \ max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \ num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,
/* get_numberthread_and_reuse_kv_v2: hand-tuned table keyed on
 * blocks = batchsize * qheads * max_num_partitions and on the q/kv head
 * ratio:
 *  - MHA keeps reusekv = 1;
 *  - ratios > 8 with more than 4000 blocks use reusekv = 16;
 *  - ratios 5 and 7 have their own thresholds.
 * paged_attention_v2_launcher_opt_tc begins here; its template parameter
 * list was stripped by extraction. */
kv_block_stride, \ kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step, \ attn_masks_ptr, attn_masks_stride); \ hipLaunchKernelGGL( \ (vllm::paged_attention_v2_reduce_kernel_opt_tc), \ dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ max_num_partitions); void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize,int max_num_partitions,int qheads,int kvheads){ reusekv=1; int blocks=batchsize*qheads*max_num_partitions; if(qheads==kvheads){ if(blocks<=80||blocks>8000){num_thread=256;} else if(blocks<=160){num_thread=128;} else num_thread=64; return; } if(qheads/kvheads>8&&blocks>4000){ reusekv=16; if(blocks>40000)num_thread=64; else num_thread=128; } else if(qheads/kvheads==5||qheads/kvheads==7){ if(blocks<=160){reusekv=1;num_thread=256;} else if(blocks<640/5*qheads/kvheads){reusekv=4;num_thread=256;} else if(blocks<1920){reusekv=8;num_thread=128;} else {reusekv=8;num_thread=64;} } else if(qheads>kvheads*4){ if(blocks<=128){reusekv=1;num_thread=256;} else if(blocks<1536){reusekv=4;num_thread=256;} else if(blocks<6144){reusekv=8;num_thread=128;} else {reusekv=8;num_thread=64;} } else { if(blocks<=128){reusekv=1;num_thread=256;} else if(blocks<3000){reusekv=4;num_thread=256;} else {reusekv=4;num_thread=64;} } } template void paged_attention_v2_launcher_opt_tc( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, const c10::optional& alibi_slopes, float k_scale, float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step, const
/* v2 launcher body: raw pointers for out, exp_sums, max_logits, tmp_out,
 * the query and both caches, block tables and seq lens, plus the optional
 * alibi slopes and attention masks. */
c10::optional& attn_masks, const int attn_masks_stride) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); int max_num_blocks_per_seq = block_tables.size(1); int q_stride = query.stride(0); int kv_block_stride = key_cache.stride(0); int kv_head_stride = key_cache.stride(1); // printf("paged_attention_v2\n"); int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. const float* alibi_slopes_ptr = alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); // NOTE: attn_masks is optional. const int* attn_masks_ptr = attn_masks ?
/* max_num_partitions = ceil(max_seq_len / PARTITION_SIZE).
 * The reduce kernel gets a (num_heads, num_seqs) grid and
 * 2 * max_num_partitions floats of dynamic shared memory.
 * The main kernel grid is
 * ((num_heads / num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES *
 * num_kv_heads, max_num_partitions, num_seqs). Its logits buffer takes
 * REUSE_KV_TIMES * PARTITION_SIZE * 2 bytes, so its shared memory is
 * bounded by the partition size rather than the sequence length.
 * As in v1: only BLOCK_SIZE == 16, non-sparse, 2-byte T, kAuto
 * instantiations launch anything; HEAD_SIZE is fixed at 128; and the PA_*
 * environment overrides apply.
 * The head_size / gfx928 check is commented out, so this launcher itself
 * verifies neither. CALL_V2_LAUNCHER instantiates the launcher with the
 * caller's arguments. */
attn_masks.value().data_ptr() : nullptr; const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 reduce_grid(num_heads, num_seqs); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){ //if(head_size==128&&get_device_name()=="gfx928"){ constexpr int HEAD_SIZE=128; constexpr static int use_vmac = false; int reusekv, num_thread; get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads); if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES; if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE; REUSEKV_SWITCH(reusekv,[&] { NUM_THREADS_SWITCH(num_thread , [&] { constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2; int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); dim3 grid; grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads; grid.y = max_num_partitions; grid.z = num_seqs; dim3 block(NUM_THREADS); int shared_mem_size = ::max(logits_size, outputs_size); if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n", reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs); LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE); }); }); } //} } #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ paged_attention_v2_launcher_opt_tc( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step, attn_masks, attn_masks_stride); #define
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ switch (is_block_sparse) { \ case true: \ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ break; \ case false: \ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ break; \ } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ switch (block_size) { \ case 8: \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ break; \ case 16: \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ break; \ case 32: \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } void paged_attention_v2_opt( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] int64_t num_kv_heads, // [num_heads] double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step, const c10::optional& attn_masks, // [num_seqs, max_seq_len] const int64_t attn_masks_stride); void paged_attention_v2_opt_tc( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, 
num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] int64_t num_kv_heads, // [num_heads] double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step, const c10::optional& attn_masks, // [num_seqs, max_seq_len] const int64_t attn_masks_stride) { const bool is_block_sparse = (blocksparse_vert_stride > 1); if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, blocksparse_block_size,blocksparse_head_sliding_step, attn_masks, attn_masks_stride); } else{ DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) } } #undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP