// attention_kernels_opt_tc.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>

#include "attention_dtypes.h"
#include "attention_utils.cuh"


#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;

#define WARP_SIZE 64

#include "static_switch_tc.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

std::string get_device_name()
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
{
    hipDeviceProp_t props{};
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
    {
        return std::string();
    }

    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess)
    {
        return std::string();
    }
    const std::string raw_name(props.gcnArchName);
    return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
}
zhangshao's avatar
zhangshao committed
39
40
41

static const std::string device_name=get_device_name();

// Reads an integer tuning knob from the environment.
// Returns 0 when the variable is not set; otherwise returns what atoi parses
// from its value (atoi also yields 0 for non-numeric text).
static inline int get_env_(const char *env_var) {
  const char *raw_value = std::getenv(env_var);
  if (raw_value == nullptr) {
    return 0;
  }
  return atoi(raw_value);
}

// Process-wide knobs, read once during static initialization of this file.
static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES");
static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE");
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
namespace vllm {

// Block-wide sum used by the softmax.
// Each wavefront first reduces `sum` with XOR shuffles, lane 0 of every
// wavefront publishes its partial into red_smem[warp], and after a barrier
// the first NUM_WARPS lanes reduce those partials. The value of lane 0 is then
// broadcast with VLLM_SHFL_SYNC and returned.
// red_smem must have room for NUM_WARPS floats. There is no barrier after the
// final read of red_smem, so callers must synchronize before overwriting it.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  const int lane_id = threadIdx.x % WARP_SIZE;
  const int warp_id = threadIdx.x / WARP_SIZE;

  // Intra-wavefront butterfly reduction.
#pragma unroll
  for (int offset = WARP_SIZE >> 1; offset > 0; offset >>= 1) {
    sum += VLLM_SHFL_XOR_SYNC(sum, offset);
  }

  if (lane_id == 0) red_smem[warp_id] = sum;
  __syncthreads();

  // Lanes [0, NUM_WARPS) pick up one partial each and combine them.
  if (lane_id < NUM_WARPS) sum = red_smem[lane_id];
#pragma unroll
  for (int offset = NUM_WARPS >> 1; offset > 0; offset >>= 1) {
    sum += VLLM_SHFL_XOR_SYNC(sum, offset);
  }

  return VLLM_SHFL_SYNC(sum, 0);
}

// Register vector types that feed the 16x16x16 MMAC instructions.
// half4_t: four 16-bit floats per lane. It is also used as a raw container for
// bf16 bits (float4_2_half4 writes bf16 into it; builtin_amdgcn_mmac
// reinterprets it as v4bh for the bf16 builtin).
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
// v4bh: four 16-bit integers, the operand type passed to the bf16 MMAC builtin.
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
// float4_t: four fp32 values, the per-lane MMAC accumulator/result fragment.
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
// half4x2: eight 16-bit values (two half4_t) that the kernels load from Q/K
// memory in a single reinterpret_cast access.
struct half4x2{
  half4_t data[2];
};

// Narrows four fp32 values into a half4_t register.
// is_half == true : element-wise conversion to _Float16.
// is_half == false: each value is converted with __float2bfloat16 and the
//                   resulting bf16 bits are stored in dst's 16-bit slots.
template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{
  if constexpr (!is_half) {
    __nv_bfloat16* bf_slots = reinterpret_cast<__nv_bfloat16 *>(&dst);
    #pragma unroll
    for (int e = 0; e < 4; ++e) {
      bf_slots[e] = __float2bfloat16(src[e]);
    }
  } else {
    #pragma unroll
    for (int e = 0; e < 4; ++e) {
      dst[e] = src[e];
    }
  }
}

// Issues one 16x16x16 matrix multiply-accumulate through inline asm.
// Operand order is (dst, a, b, c) with c tied to dst ("0" constraint), i.e.
// presumably reg_c = A*B + reg_c, where reg_a/reg_b carry this lane's four
// 16-bit operand elements and reg_c its four fp32 accumulator elements.
// Despite the "_f16" name, is_half == false selects the bf16 instruction, in
// which case reg_a/reg_b are expected to hold raw bf16 bits.
template<bool is_half>
inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
    
    if constexpr (is_half){
     // fp16 x fp16 -> fp32 accumulate.
     asm volatile("v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" : 
             "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
    }
    else{
     // bf16 x bf16 -> fp32 accumulate (operands reinterpreted as bf16 bits).
     asm volatile("v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" : 
       "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
    }
}

// Dispatches one 16x16x16 MMAC step that accumulates into reg_c.
// use_vmac == true : go through the inline-asm wrapper v_mmac_f32_16x16x16_f16.
// use_vmac == false: use the compiler builtin for fp16 (is_half) or bf16; the
//                    bf16 builtin takes v4bh operands, so the 16-bit lanes of
//                    reg_a/reg_b are reinterpreted in place.
template<bool is_half,bool use_vmac>
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
    if constexpr (use_vmac) {
      v_mmac_f32_16x16x16_f16<is_half>(reg_a, reg_b, reg_c);
    } else if constexpr (is_half) {
      reg_c = __builtin_amdgcn_mmac_f32_16x16x16f16(reg_a, reg_b, reg_c);
    } else {
      const v4bh a_bits = *(v4bh*)&reg_a;
      const v4bh b_bits = *(v4bh*)&reg_b;
      reg_c = __builtin_amdgcn_mmac_f32_16x16x16bf16(a_bits, b_bits, reg_c);
    }
}

// TODO(woosuk): Merge the last two dimensions of the grid.
//
// Shared body of the MMAC ("tensor core") paged-attention kernels.
//
// Launch layout, as consumed by the indexing below:
//   blockIdx.x : partition index (gridDim.x == max_num_partitions)
//   blockIdx.y : group of up to REUSE_KV_TIMES consecutive query heads that
//                share one KV head (num_blocks_per_kv groups per KV head)
//   blockIdx.z : sequence index
//   blockDim.x : assumed to be NUM_THREADS (NUM_THREADS / WARP_SIZE wavefronts)
//
// Phases:
//   1. Q*K^T for every token of the partition with 16x16x16 MMAC. Scaled logits
//      (plus ALiBi when alibi_slopes != nullptr) go to dynamic shared memory as
//      scalar_t, one partition_size-long row per reused head.
//   2. Per-head softmax over those rows. With partitioning, thread 0 also
//      stores the partition's max logit and exp sum.
//   3. softmax*V with MMAC, a cross-wavefront reduction through shared memory,
//      then wavefront 0 writes HEAD_SIZE outputs per valid head.
//
// NOTE(review): the dynamic shared buffer holds REUSE_KV_TIMES * partition_size
// scalar_t logits and is later reused for the cross-wavefront reduction; the
// launcher must size it for both uses -- confirm.
// NOTE(review): the hard-coded strides (K chunks of 512 elements, V offsets of
// 1024/256, 16 query chunks of 8 elements, 16-bit reinterpretation of cache_t)
// appear to assume HEAD_SIZE == 128, BLOCK_SIZE == 16, a 2-byte cache_t and
// REUSE_KV_TIMES <= 16 -- confirm against the instantiations in
// static_switch_tc.h.
// KV_DTYPE, IS_BLOCK_SPARSE, k_scale, v_scale, tp_rank and the blocksparse_*
// arguments are accepted but not used in this body.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES,bool use_vmac,int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel_TC(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                 // head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_heads,
    const int num_kv_heads,
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  const int seq_idx = blockIdx.z;
  const int partition_idx = blockIdx.x;
  const int max_num_partitions = gridDim.x;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
    // No work to do. Terminate the thread block (uniform for the whole block).
    return;
  }
  constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
  static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
  const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx = partition_idx * num_blocks_per_partition;
  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
  const int lane = thread_idx % WARP_SIZE;
  // A 64-lane wavefront is viewed as 4 row groups ("rows") of 16 lanes ("rowid").
  const int rowid = lane%16;
  const int rows = lane/16;

  // head_idx is the first query head of this block's group. The last group of
  // a KV head is partial when num_queries_per_kv is not a multiple of
  // REUSE_KV_TIMES; q_boundary is the number of valid heads in this group.
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int num_blocks_per_kv = DIVIDE_ROUND_UP(num_queries_per_kv, REUSE_KV_TIMES);
  const int head_idx=(blockIdx.y / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.y % num_blocks_per_kv) * REUSE_KV_TIMES;
  int q_boundary=REUSE_KV_TIMES;
  if(num_heads < REUSE_KV_TIMES*gridDim.y && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv)
    q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES;
  const int kv_head_idx = head_idx / num_queries_per_kv;

  // Reused heads handled by a lane: rows + 4*i for i in [0, reuse_group).
  constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
  float alibi_slope[reuse_group]={0.f};
  if(alibi_slopes != nullptr){
    for(int i=0;i<reuse_group;i++){
      int reuse_kv_idx=rows+i*4;
      if(reuse_kv_idx<q_boundary) alibi_slope[i]=alibi_slopes[head_idx+reuse_kv_idx];
    }
  }
  float qk_max[reuse_group];
  for(int i=0;i<reuse_group;i++){
    qk_max[i]=-FLT_MAX;
  }

  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;

  // Lanes whose rowid is not a valid head keep an all-zero query operand.
  half4x2 q_vec;
  q_vec.data[0]={0,0,0,0};
  q_vec.data[1]={0,0,0,0};

  // Stage the group's queries in shared memory: 16 chunks of 8 elements per head.
  __shared__ half4x2 q_vecs[REUSE_KV_TIMES][16];
  for(int i=0;i<q_boundary;i++){
    if(thread_idx<16){
      q_vecs[i][thread_idx]=*reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
    }
  }
  __syncthreads();

  // Memory planning.
  extern __shared__ char shared_mem[];
  // Logits are kept in shared memory as scalar_t (not FP32): row r (reused
  // head r) starts at logits + r * partition_size.
  scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem);
  // Workspace for reduction: [0, NUM_WARPS) for the max, [NUM_WARPS, 2*NUM_WARPS)
  // for block_sum.
  __shared__ float red_smem[2 * NUM_WARPS];

  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8;

  // Phase 1: each wavefront takes one KV block per iteration and computes the
  // 16 tokens x 16 heads score tile with MMAC, software-pipelining the K loads.
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {

    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
    const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
    float4_t qk_vec={0,0,0,0};

    half4x2 k_vec[2];
    k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
    #pragma unroll
    for(int i=0;i<3;i++){
      if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
      // Prefetch the next K chunk into the other buffer while computing this one.
      k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*512);
      builtin_amdgcn_mmac<is_half,use_vmac>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
      builtin_amdgcn_mmac<is_half,use_vmac>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
    }
    //tail
    {
      if(rowid<q_boundary)q_vec=q_vecs[rowid][3*4+rows];
      builtin_amdgcn_mmac<is_half,use_vmac>(k_vec[1].data[0],q_vec.data[0],qk_vec);
      // NOTE(review): this last step always uses the inline-asm path regardless
      // of use_vmac (kept as-is; confirm whether that is intentional).
      v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
    }
    #pragma unroll
    for(int i=0;i<reuse_group;i++){
      int reuse_kv_idx=rows+i*4;
      if(reuse_kv_idx<REUSE_KV_TIMES){
        if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
        else qk_vec[i]*=scale;
        const int token_idx = block_idx * BLOCK_SIZE+rowid;
        if(alibi_slope[i] != 0){
          float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
          qk_vec[i] += alibi;
        }

        // Tokens past the end of the sequence get a 0 logit and do not affect the max.
        const bool mask = (token_idx >= seq_len);
        if(mask){
          from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
        }
        else{
          from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
          qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
        }
      }
    }
  }

  // Phase 2: softmax, one reused head at a time. For head r, lanes with
  // rows == r % 4 hold its running max in qk_max[r / 4]; reduce it across the
  // 16 lanes of that row group, then across wavefronts through red_smem.
  for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
    const int head_idx_ = head_idx + reuse_kv_idx;
    float qk_max_tmp=qk_max[reuse_kv_idx/4];
    float exp_sum = 0.f;
    #pragma unroll
    for (int mask = 8; mask >= 1; mask /= 2) {
      qk_max_tmp = fmaxf(qk_max_tmp, VLLM_SHFL_XOR_SYNC(qk_max_tmp, mask));
    }
    if (rowid==0 && reuse_kv_idx%4==rows) {
      red_smem[warp_idx] = qk_max_tmp;
    }
    __syncthreads();

    // TODO(woosuk): Refactor this part.
    // Get the max qk value for the sequence.
    qk_max_tmp = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
    #pragma unroll
    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      qk_max_tmp = fmaxf(qk_max_tmp, VLLM_SHFL_XOR_SYNC(qk_max_tmp, mask));
    }
    // Broadcast the max qk value to all threads.
    qk_max_tmp = VLLM_SHFL_SYNC(qk_max_tmp, 0);
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
      from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
      exp_sum += val;
    }
    exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
    // Compute softmax.
    const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
    }
    // Also orders this iteration's red_smem reads before the next iteration's writes.
    __syncthreads();

    // If partitioning is enabled, store the max logit and exp_sum.
    if (USE_PARTITIONING && thread_idx == 0) {
      float* max_logits_ptr = max_logits +
                              seq_idx * num_heads * max_num_partitions +
                              head_idx_ * max_num_partitions + partition_idx;
      *max_logits_ptr = qk_max_tmp;
      float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
                            head_idx_ * max_num_partitions + partition_idx;
      *exp_sums_ptr = exp_sum;
    }
  }

  // Phase 3: softmax * V.
  constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);
  if constexpr(REUSE_KV_TIMES<=2){
    // accs[r][i]: head r, output row lane + i*WARP_SIZE (each lane keeps only
    // the MMAC tile k == rows, i.e. rows rowid + 16*rows + i*WARP_SIZE).
    float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
    #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      #pragma unroll
      for(int k=0;k<REUSE_KV_TIMES;k++)
      {
        accs[k][i] = 0.f;
      }
    }
    scalar_t zero_value;
    zero(zero_value);
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
      const int64_t physical_block_number =
          static_cast<int64_t>(block_table[block_idx]);
      const int token_idx = block_idx * BLOCK_SIZE +rows*4;
      half4_t logits_vec={0,0,0,0};
      if(rowid<4*q_boundary){
          logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
      }
      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                            kv_head_idx * kv_head_stride + rows*4+rowid*16;
      #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        #pragma unroll
        for(int k=0;k<4;k++){
          int offset=i*1024+k*256;
          half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
          if (block_idx == num_seq_blocks - 1) {
            // Zero V for padding tokens of the last block so garbage (possibly
            // NaN) cannot leak into the product.
            scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
            #pragma unroll
            for (int j = 0; j < 4; j++) {
              v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
            }
          }
          float4_t out_vec={0,0,0,0};
          builtin_amdgcn_mmac<is_half,use_vmac>(v_vec,logits_vec,out_vec);
          if(rows==k){
            for(int resuseid=0;resuseid<REUSE_KV_TIMES;resuseid++){
              accs[resuseid][i]+=out_vec[resuseid];
            }
          }
        }
      }
    }
    // The logits buffer is reused below as reduction scratch.
    __syncthreads();
    using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
    // Perform reduction across warps.
    for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
      if constexpr (NUM_THREADS>64){
        floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
        #pragma unroll
        for (int i = NUM_WARPS; i > 1; i /= 2) {
          int mid = i / 2;
          // Upper warps write to shared memory.
          if (warp_idx >= mid && warp_idx < i) {
            out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
          }
          __syncthreads();
          // Lower warps update the output.
          if (warp_idx < mid) {
            floatV_t tmp=out_smem[thread_idx];
            #pragma unroll
            for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
              accs[reuse_kv_idx][r] += tmp[r];
            }
          }
          __syncthreads();
        }
      }
      // Write the final output.
      if (warp_idx == 0) {
        scalar_t* out_ptr =
            out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
            (head_idx+reuse_kv_idx) * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
        #pragma unroll
        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
          const int row_idx = lane + i * WARP_SIZE;
          from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
        }
      }
    }
  }
#if defined __gfx928__
  else{
    constexpr int GROUPS=reuse_group*4;
    // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
    // accs[g*4+k][i]: head g*4 + rows, output row rowid + 16*k + i*WARP_SIZE.
    float accs[GROUPS][NUM_ROWS_PER_THREAD];
    #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      #pragma unroll
      for(int k=0;k<GROUPS;k++)
      {
        accs[k][i] = 0.f;
      }
    }
    scalar_t zero_value;
    zero(zero_value);
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
      const int64_t physical_block_number =
          static_cast<int64_t>(block_table[block_idx]);
      const int token_idx = block_idx * BLOCK_SIZE +rows*4;
      half4_t logits_vec={0,0,0,0};
      if(rowid<q_boundary){
          logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
      }
      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                            kv_head_idx * kv_head_stride + rows*4+rowid*16;
      #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        #pragma unroll
        for(int k=0;k<4;k++){
          int offset=i*1024+k*256;
          half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
          if (block_idx == num_seq_blocks - 1) {
            scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
            #pragma unroll
            for (int j = 0; j < 4; j++) {
              v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
            }
          }
          float4_t out_vec={0,0,0,0};
          builtin_amdgcn_mmac<is_half,use_vmac>(v_vec,logits_vec,out_vec);
          for(int g=0;g<reuse_group;g++){
            accs[g*4+k][i]+=out_vec[g];
          }
        }
      }
    }
    if constexpr (NUM_THREADS>64){
      // The logits buffer is reused below as reduction scratch.
      __syncthreads();
      using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
      // Perform reduction across warps.
      for(int reuse_kv_idx=0; reuse_kv_idx<GROUPS; reuse_kv_idx++) {
        floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
        #pragma unroll
        for (int i = NUM_WARPS; i > 1; i /= 2) {
          int mid = i / 2;
          // Upper warps write to shared memory.
          if (warp_idx >= mid && warp_idx < i) {
            out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
          }
          __syncthreads();
          // Lower warps update the output.
          if (warp_idx < mid) {
            floatV_t tmp=out_smem[thread_idx];
            #pragma unroll
            for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
              accs[reuse_kv_idx][r] += tmp[r];
            }
          }
          __syncthreads();
        }
      }
    }
    if (warp_idx == 0) {
      for(int g=0;g<reuse_group;g++){
        int reusekvid=g*4+rows;
        if(reusekvid<q_boundary){
          scalar_t* out_ptr =
              out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
              (head_idx+reusekvid) * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
          #pragma unroll
          for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
            for(int k=0;k<4;k++){
              const int row_idx = rowid+16*k + i * WARP_SIZE;
              from_float(*(out_ptr + row_idx), accs[g*4+k][i]);
            }
          }
        }
      }
    }
  }
#else
  else{
    // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
    // accs[k][i] is accumulated directly by MMAC; component g belongs to head
    // g*4 + rows, output row rowid + 16*k + i*WARP_SIZE.
    float4_t accs[4][NUM_ROWS_PER_THREAD];
    #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      #pragma unroll
      for(int k=0;k<4;k++)
      {
        accs[k][i] = {0.f,0.f,0.f,0.f};
      }
    }
    scalar_t zero_value;
    zero(zero_value);
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
      const int64_t physical_block_number =
          static_cast<int64_t>(block_table[block_idx]);
      const int token_idx = block_idx * BLOCK_SIZE +rows*4;
      half4_t logits_vec={0,0,0,0};
      if(rowid<q_boundary){
          logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
      }
      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                            kv_head_idx * kv_head_stride + rows*4+rowid*16;
      #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        #pragma unroll
        for(int k=0;k<4;k++){
          int offset=i*1024+k*256;
          half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
          if (block_idx == num_seq_blocks - 1) {
            scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
            #pragma unroll
            for (int j = 0; j < 4; j++) {
              v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
            }
          }
          builtin_amdgcn_mmac<is_half,use_vmac>(v_vec,logits_vec,accs[k][i]);
        }
      }
    }
    if constexpr (NUM_THREADS>64){
      // The logits buffer is reused below as reduction scratch.
      __syncthreads();
      // Only the first reuse_group components of each accumulator are meaningful.
      using floatV_t = __attribute__( (__vector_size__(reuse_group * sizeof(float)) )) float;
      // Perform reduction across warps.
      for(int m=0; m<4; m++) {
        floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
        #pragma unroll
        for (int i = NUM_WARPS; i > 1; i /= 2) {
          int mid = i / 2;
          // Upper warps write to shared memory.
          if (warp_idx >= mid && warp_idx < i) {
            for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
              out_smem[((warp_idx - mid) * 64+lane)*NUM_ROWS_PER_THREAD+k]=*(floatV_t*)(&(accs[m][k]));
            }
          }
          __syncthreads();
          // Lower warps update the output.
          if (warp_idx < mid) {
            for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
              floatV_t tmp=out_smem[thread_idx*NUM_ROWS_PER_THREAD+k];
              #pragma unroll
              for (int g = 0; g < reuse_group; g++) {
                accs[m][k][g] += tmp[g];
              }
            }
          }
          __syncthreads();
        }
      }
    }
    if (warp_idx == 0) {
      for(int g=0;g<reuse_group;g++){
        int reusekvid=g*4+rows;
        if(reusekvid<q_boundary){
          scalar_t* out_ptr =
              out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
              (head_idx+reusekvid) * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
          #pragma unroll
          for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
            for(int k=0;k<4;k++){
              const int row_idx = rowid+16*k + i * WARP_SIZE;
              from_float(*(out_ptr + row_idx), accs[k][i][g]);
            }
          }
        }
      }
    }
  }
#endif
}


// Single-pass (no partitioning) MMAC paged-attention entry point.
// Forwards to paged_attention_kernel_TC with PARTITION_SIZE == 0 and null
// exp_sums/max_logits. Because the output stride still uses gridDim.x as
// max_num_partitions, `out` matches [num_seqs, num_heads, head_size] only when
// the grid is launched with gridDim.x == 1 -- NOTE(review): confirm at launch.
// The body is compiled only for gfx936/gfx928; for any other target this
// kernel is an empty function and leaves `out` untouched.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES,bool use_vmac>
__global__ void paged_attention_v1_kernel_TC(
    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_heads,
    const int num_kv_heads,
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
#if defined(__gfx936__) || defined(__gfx928__)
  paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                            KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, use_vmac>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
      v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
      blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
#endif
}

// Partitioned MMAC paged-attention entry point (first pass of the v2 scheme).
// Grid: (max_num_partitions, number of query-head groups, num_seqs), i.e.
// blockIdx.x = partition, blockIdx.y = group of up to REUSE_KV_TIMES heads,
// blockIdx.z = sequence, as consumed by paged_attention_kernel_TC.
// Writes per-partition partial outputs to tmp_out and the per-partition max
// logit / exp sum to max_logits / exp_sums, presumably combined afterwards by
// paged_attention_v2_reduce_kernel_opt_tc.
// The body is compiled only for gfx936/gfx928; for any other target this
// kernel is an empty function. odd_nheads is not referenced in this body.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES,bool use_vmac, int PARTITION_SIZE,
          bool odd_nheads = false>
__global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,       // [num_seqs, num_heads,
                                          // max_num_partitions]
    scalar_t* __restrict__ tmp_out,       // [num_seqs, num_heads,
                                          // max_num_partitions, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_heads,
    const int num_kv_heads,
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
#if defined(__gfx936__) || defined(__gfx928__)
  paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                            KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, use_vmac,
                            PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads,
      num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq,
      alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
#endif
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, int PARTITION_SIZE>
zhuwenwen's avatar
zhuwenwen committed
698
__global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_tc(
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int seq_len = seq_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr =
        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr =
        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits +
                                seq_idx * num_heads * max_num_partitions +
                                head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
 #pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
  #pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums =
      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums +
                              seq_idx * num_heads * max_num_partitions +
                              head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr =
      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr =
      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
  #pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
             inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

}  // namespace vllm


// Launches the single-pass (v1) TC paged-attention kernel for HEAD_SIZE.
// First sets the kernel's max-dynamic-shared-memory attribute to
// shared_mem_size, then launches on `stream`. The enclosing scope must provide
// T, CACHE_T, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,
// REUSE_KV_TIMES, use_vmac, grid, block, shared_mem_size, stream and every
// kernel argument named below.
#define LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE)                               \
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                       \
      ((void*)vllm::paged_attention_v1_kernel_TC<                             \
          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE,           \
          IS_BLOCK_SPARSE, REUSE_KV_TIMES, use_vmac>),                        \
      shared_mem_size);                                                       \
  vllm::paged_attention_v1_kernel_TC<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,       \
                                     NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,  \
                                     REUSE_KV_TIMES, use_vmac>                \
      <<<grid, block, shared_mem_size, stream>>>(                             \
          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,      \
          num_kv_heads, scale, block_tables_ptr, seq_lens_ptr,                \
          max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,                 \
          kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
          blocksparse_local_blocks, blocksparse_vert_stride,                  \
          blocksparse_block_size, blocksparse_head_sliding_step);

// Chooses the launch configuration of the single-pass (v1) TC kernel.
//
//   num_thread [out]  threads per block.
//   reusekv    [out]  REUSE_KV_TIMES; the launcher sizes the grid with
//                     ceil((qheads / kvheads) / reusekv) blocks per KV head.
//   batchsize         number of sequences.
//   seq               padded maximum sequence length.
//   qheads, kvheads   number of query / key-value heads.
//
// The thresholds are empirical tuning results. gfx936 uses its own table;
// every other device uses the generic table that follows it.
void get_numberthread_and_reuse_kv_v1(int& num_thread, int& reusekv,
                                      int batchsize, int seq, int qheads,
                                      int kvheads) {
  reusekv = 1;
  num_thread = 256;
  // Keep the defaults instead of dividing by zero below.
  if (kvheads <= 0) return;

  if (device_name == "gfx936") {
    if (qheads == kvheads) {  // MHA
      if (seq < 16) {
        num_thread = 64;
        return;
      }
      if (batchsize >= 32 && seq >= 1000) return;
      if (batchsize * qheads >= 512) num_thread = 64;
      return;
    }
    if (seq <= 16) {
      num_thread = 64;
      if (qheads * batchsize > 1000) reusekv = 4;
      return;
    }
    if (seq <= 64) {
      if (qheads * batchsize > 1000) reusekv = 4;
      return;
    }
    if (seq <= 200) {
      if (qheads * batchsize > 400) reusekv = 4;
      return;
    }
    if (seq <= 500) {
      if (qheads * batchsize > 200) reusekv = 4;
      return;
    }
    if (((qheads - 1) / 16 + 1) * batchsize >= 64 && qheads / kvheads > 4 &&
        seq < 7800) {
      reusekv = 8;
    } else if (qheads * batchsize > 100) {
      reusekv = 4;
    }
    return;
  }

  if (qheads == kvheads) {  // MHA
    // Tuned on llama 7B; other models have not been measured.
    if (seq <= 16 || batchsize >= 32) num_thread = 64;
    else if (batchsize <= 2) num_thread = 256;
    else if (batchsize < 8) num_thread = 128;
    else num_thread = 64;
    return;
  }

  // MQA / large GQA groups: more than 4 query heads per KV head.
  if (qheads > kvheads * 4) {
    if (seq < 64) {
      num_thread = 64;
      if (batchsize <= 64) reusekv = 1;
      else if (batchsize < 128) reusekv = 2;
      else reusekv = 4;
    } else if (seq <= 400) {
      if (batchsize < 16) {
        reusekv = 1;
        num_thread = 256;
      } else if (batchsize < 64) {
        reusekv = 2;
        num_thread = 256;
      } else if (batchsize <= 128) {
        reusekv = 4;
        if (qheads % 7 == 0) num_thread = 64;  // qwen7b
        else num_thread = 256;                 // llama70b
      } else {
        reusekv = 8;
        num_thread = 64;
      }
    } else if (seq <= 1000) {
      if (batchsize < 16) {
        reusekv = 1;
        num_thread = 256;
      } else if (qheads % 7 == 0 && batchsize <= 128) {  // qwen7b
        reusekv = 4;
        num_thread = batchsize < 64 ? 256 : 64;
      } else if (batchsize <= 64) {
        reusekv = 4;
        num_thread = 256;
      } else {
        reusekv = 8;
        num_thread = 128;
      }
    } else if (seq < 3900) {
      reusekv = 8;
      num_thread = 256;
    } else if (seq < 7800) {
      reusekv = 4;
      num_thread = 256;
    } else {
      reusekv = 2;
      num_thread = 256;
    }
    return;
  }

  // Small GQA groups. Here qheads <= 4 * kvheads, so qheads / kvheads <= 4 and
  // the former "qheads / kvheads > 4 -> reusekv = 8" case cannot occur.
  if (qheads / kvheads > 2 && seq < 7800) reusekv = 4;
  else if (qheads / kvheads >= 2 && seq < 15600) reusekv = 2;

  if (seq <= 64) {
    num_thread = 64;
    if (batchsize <= 64) reusekv = 1;
  } else {
    num_thread = 256;
  }
}

// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
zhuwenwen's avatar
zhuwenwen committed
907
void paged_attention_v1_launcher_opt_tc(
908
909
910
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
911
912
    const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
913
    const int blocksparse_vert_stride, const int blocksparse_block_size,
914
    const int blocksparse_head_sliding_step) {
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);
  int num_threads = 128;
  // printf("paged_attention_v1\n");
  if (num_heads != num_kv_heads) {
    num_threads = 256;
  }
  [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;
935
936
  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
937
938
939
940
941
942
  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
943

944
945
946
947
948
949
950
951
952
953
954
955
956
957
  int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
      constexpr int HEAD_SIZE=128;
      constexpr static int use_vmac = false;
      int reusekv, num_thread;
      get_numberthread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads);
      if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
      if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
      REUSEKV_SWITCH(reusekv,[&] {
        NUM_THREADS_SWITCH(num_thread , [&] {
          //constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES; 
          constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
zhangshao's avatar
zhangshao committed
958
959
960
          int logits_size = REUSE_KV_TIMES  * padded_max_seq_len * 2;
          int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
          if (NUM_WARPS==64)outputs_size=0;
961
          int shared_mem_size = ::max(logits_size, outputs_size);
zhangshao's avatar
zhangshao committed
962
          dim3 grid(1,(num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads,num_seqs);
963
          dim3 block(NUM_THREADS);
zhangshao's avatar
zhangshao committed
964
965
          if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
                                    reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
966
967
968
969
970
971
972
          LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE);
        });
      });
    }
}

// Calls the v1 TC launcher instantiation for the given template arguments,
// forwarding the parameters of the enclosing paged_attention_v1_opt_tc by name.
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)  \
  paged_attention_v1_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,       \
                                     IS_BLOCK_SPARSE>(                       \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
      seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank,        \
      blocksparse_local_blocks, blocksparse_vert_stride,                     \
      blocksparse_block_size, blocksparse_head_sliding_step);

// Selects the dense or block-sparse v1 launcher instantiation from the runtime
// `is_block_sparse` flag of the enclosing function.
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
  if (is_block_sparse) {                                                   \
    CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);       \
  } else {                                                                 \
    CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);      \
  }

// Maps the runtime `block_size` onto one of the compiled BLOCK_SIZE
// instantiations (8, 16 or 32) of the v1 launcher; other values are rejected.
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)           \
  {                                                                 \
    if (block_size == 8) {                                          \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);           \
    } else if (block_size == 16) {                                  \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);          \
    } else if (block_size == 32) {                                  \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);          \
    } else {                                                        \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
    }                                                               \
  }

1005
void paged_attention_v1(
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
    torch::Tensor& out,    // [num_seqs, num_heads, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
1018
    const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
1019
1020
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
1021
    const int64_t blocksparse_head_sliding_step);
1022

zhuwenwen's avatar
zhuwenwen committed
1023
void paged_attention_v1_opt_tc(
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
    torch::Tensor& out,    // [num_seqs, num_heads, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
1036
    const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, 
1037
1038
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
1039
    const int64_t blocksparse_head_sliding_step) {
1040
1041
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
  if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
zhangshao's avatar
zhangshao committed
1042
      block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
1043
    paged_attention_v1(out,query,key_cache,value_cache,num_kv_heads,
1044
1045
                       scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
                       k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
1046
                       blocksparse_block_size,blocksparse_head_sliding_step);
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
  }
  else{
    DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V1_LAUNCHER_BLOCK_SIZE)
  }
}

// Launches the partitioned (v2) TC kernel followed by the reduction kernel,
// both on `stream`. The enclosing scope must provide:
//   - template arguments: T, CACHE_T, BLOCK_SIZE, NUM_THREADS, KV_DTYPE,
//     IS_BLOCK_SPARSE, REUSE_KV_TIMES, use_vmac, PARTITION_SIZE;
//   - launch parameters: grid, block, reduce_grid, shared_mem_size,
//     reduce_shared_mem_size, stream;
//   - every kernel argument named below.
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE)                                \
  hipLaunchKernelGGL(                                                          \
      (vllm::paged_attention_v2_kernel_TC<                                     \
          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE,            \
          IS_BLOCK_SPARSE, REUSE_KV_TIMES, use_vmac, PARTITION_SIZE>),         \
      dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr,          \
      max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,  \
      num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr,          \
      max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride,     \
      kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank,                       \
      blocksparse_local_blocks, blocksparse_vert_stride,                       \
      blocksparse_block_size, blocksparse_head_sliding_step);                  \
  hipLaunchKernelGGL(                                                          \
      (vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE,             \
                                                     NUM_THREADS,              \
                                                     PARTITION_SIZE>),         \
      dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
      exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr,                 \
      max_num_partitions);

zhangshao's avatar
zhangshao committed
1073
void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize,int max_num_partitions,int qheads,int kvheads,int num_blocks){
1074
  reusekv=1;
zhangshao's avatar
zhangshao committed
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
  num_thread=256;
  if(device_name=="gfx936"){//bw
    if(max_num_partitions==16&&num_blocks==1024){//ali test
      if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
      if(batchsize==1&&qheads==32&&kvheads==32){num_thread=64;return;}
      if(batchsize==1){
        if(qheads==52){reusekv=8;return;}
        if(qheads==13){reusekv=2;return;}
        reusekv=4;return;
      }
      if(batchsize==64){
        if(qheads==13||qheads==32){num_thread=128;reusekv=8;return;}
        if(qheads==52){reusekv=16;return;}
        reusekv=8;return;
      }
    }
    if(qheads==kvheads)return;
    int bp=max_num_partitions*batchsize;
    if(qheads/kvheads>4){
      if(qheads==16&&bp>96||qheads<16&&bp>=192||qheads>16&&bp>24){reusekv=8;return;}
    }
    if(qheads/4*bp>=32)reusekv=4;
    return;
  }
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
  int blocks=batchsize*qheads*max_num_partitions;
  if(qheads==kvheads){
    if(blocks<=80||blocks>8000){num_thread=256;}
    else if(blocks<=160){num_thread=128;}
    else num_thread=64;
    return;
  }
  if(qheads/kvheads>8&&blocks>4000){
    reusekv=16;
    if(blocks>40000)num_thread=64;
    else num_thread=128;
  }
  else if(qheads/kvheads==5||qheads/kvheads==7){
    if(blocks<=160){reusekv=1;num_thread=256;}
    else if(blocks<640/5*qheads/kvheads){reusekv=4;num_thread=256;}
    else if(blocks<1920){reusekv=8;num_thread=128;}
    else {reusekv=8;num_thread=64;}
  }
  else if(qheads>kvheads*4){
    if(blocks<=128){reusekv=1;num_thread=256;}
    else if(blocks<1536){reusekv=4;num_thread=256;}
    else if(blocks<6144){reusekv=8;num_thread=128;}
    else {reusekv=8;num_thread=64;}
  }
  else {
    if(blocks<=128){reusekv=1;num_thread=256;}
    else if(blocks<3000){reusekv=4;num_thread=256;}
    else {reusekv=4;num_thread=64;}
  }
}

template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, int PARTITION_SIZE = 512>
zhuwenwen's avatar
zhuwenwen committed
1132
void paged_attention_v2_launcher_opt_tc(
1133
1134
1135
1136
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
1137
1138
    const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
1139
    const int blocksparse_vert_stride, const int blocksparse_block_size,
1140
    const int blocksparse_head_sliding_step) {
1141
1142
1143
1144
1145
1146
1147
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);
zhangshao's avatar
zhangshao committed
1148
  int num_blocks=key_cache.size(0);
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
  // printf("paged_attention_v2\n");
  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
1160
1161
  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
1162
1163
1164
1165
1166
1167
1168
1169
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
1170

1171
1172
1173
1174
1175
1176
1177
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 reduce_grid(num_heads, num_seqs);
  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
zhangshao's avatar
zhangshao committed
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
    constexpr int HEAD_SIZE=128;
    constexpr static int use_vmac = false;
    int reusekv, num_thread;
    get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads,num_blocks);
    if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
    if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
    REUSEKV_SWITCH(reusekv,[&] {
      NUM_THREADS_SWITCH(num_thread , [&] {
        constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
        int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
        int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
        dim3 grid;
        grid.y = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
        grid.x = max_num_partitions;
        grid.z = num_seqs;
        dim3 block(NUM_THREADS);
        int shared_mem_size = ::max(logits_size, outputs_size);
        if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
                                  reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
        LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE);
1198
      });
zhangshao's avatar
zhangshao committed
1199
1200
    });
  }
1201
1202
1203
}

// Calls the v2 TC launcher instantiation for the given template arguments,
// forwarding the parameters of the enclosing paged_attention_v2_opt_tc by name.
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)   \
  paged_attention_v2_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,        \
                                     IS_BLOCK_SPARSE>(                        \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
      num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
      k_scale, v_scale, tp_rank, blocksparse_local_blocks,                    \
      blocksparse_vert_stride, blocksparse_block_size,                        \
      blocksparse_head_sliding_step);

// Selects the dense or block-sparse v2 launcher instantiation from the runtime
// `is_block_sparse` flag of the enclosing function.
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
  if (is_block_sparse) {                                                   \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);       \
  } else {                                                                 \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);      \
  }

// Maps the runtime `block_size` onto one of the compiled BLOCK_SIZE
// instantiations (8, 16 or 32) of the v2 launcher; other values are rejected.
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)           \
  {                                                                 \
    if (block_size == 8) {                                          \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);           \
    } else if (block_size == 16) {                                  \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);          \
    } else if (block_size == 32) {                                  \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);          \
    } else {                                                        \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
    }                                                               \
  }

1237
void paged_attention_v2(
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
1254
    const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, 
1255
1256
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
1257
    const int64_t blocksparse_head_sliding_step);
1258

zhuwenwen's avatar
zhuwenwen committed
1259
void paged_attention_v2_opt_tc(
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
1276
    const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
1277
1278
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
1279
    const int64_t blocksparse_head_sliding_step) {
1280
1281
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
  if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
zhangshao's avatar
zhangshao committed
1282
      block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
1283
    paged_attention_v2(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
1284
1285
                       scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
                       k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
1286
                       blocksparse_block_size,blocksparse_head_sliding_step);
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
  }
  else{
    DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
  }
}

// Drop the file-local helper macros so they do not leak into other code.
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP