// attention_kernels_opt_tc.cu
//
// HIP/ROCm paged-attention kernels built on 16x16x16 MMAC instructions.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <cfloat>   // FLT_MAX
#include <cstdlib>  // std::getenv, atoi
#include <string>   // std::string

#include "attention_dtypes.h"
#include "attention_utils.cuh"

#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"

// Lets code written against the CUDA bf16 type name compile against HIP's type.
typedef __hip_bfloat16 __nv_bfloat16;

// This file assumes 64-lane wavefronts throughout (e.g. lane / 16 yields 4 rows).
#define WARP_SIZE 64

#include "static_switch_tc.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

zhangshao's avatar
zhangshao committed
21
std::string get_device_name()
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
{
    hipDeviceProp_t props{};
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
    {
        return std::string();
    }

    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess)
    {
        return std::string();
    }
    const std::string raw_name(props.gcnArchName);
    return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
}
zhangshao's avatar
zhangshao committed
39
40
41

static const std::string device_name=get_device_name();

// Reads environment variable `env_var` as an int.
// Returns 0 when the variable is unset. Parsing follows atoi(): leading
// digits are converted, and a value with no leading number yields 0.
static inline int get_env_(const char *env_var) {
  if (const char *value = std::getenv(env_var)) {
    return atoi(value);
  }
  return 0;
}

// Environment overrides, read once during static initialization; each is 0
// when the corresponding variable is unset.
static const int PA_USE_V1 = get_env_("PA_USE_V1");
static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES");
static const int PA_PARTITION_SIZE = get_env_("PA_PARTITION_SIZE");
static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE");
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
namespace vllm {

// Utility function for attention softmax.
// Sums `sum` over every thread of the block and returns the total to all
// threads.
//   red_smem: scratch of at least NUM_WARPS floats (callers in this file pass
//             __shared__ arrays); entry w receives wavefront w's partial sum.
// Contains a __syncthreads(), so every thread of the block must call it.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;

  // Butterfly (xor) reduction inside each wavefront.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
    sum += VLLM_SHFL_XOR_SYNC(sum, offset);
  }

  // Lane 0 of each wavefront publishes that wavefront's partial sum.
  if (lane_id == 0) red_smem[warp_id] = sum;
  __syncthreads();

  // Lanes [0, NUM_WARPS) load the partials and combine them.
  if (lane_id < NUM_WARPS) sum = red_smem[lane_id];
#pragma unroll
  for (int offset = NUM_WARPS / 2; offset > 0; offset >>= 1) {
    sum += VLLM_SHFL_XOR_SYNC(sum, offset);
  }

  // Lane 0 holds the block total; hand it to every lane.
  return VLLM_SHFL_SYNC(sum, 0);
}

using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
94
using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
95
96
97
98
struct half4x2{
  half4_t data[2];
};

99
100
101
102
103
template<typename scalar_t> 
struct vec2data{
  scalar_t data[2];
};

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Narrows four fp32 values into the four 16-bit slots of `dst`.
//   is_half == true : each float is assigned to the matching _Float16 lane.
//   is_half == false: each float is converted with __float2bfloat16 and the
//                     bf16 result is stored into dst's storage.
template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{
  if constexpr (!is_half) {
    __nv_bfloat16* bf16_slots = reinterpret_cast<__nv_bfloat16*>(&dst);
    #pragma unroll
    for (int slot = 0; slot < 4; ++slot) {
      bf16_slots[slot] = __float2bfloat16(src[slot]);
    }
  } else {
    #pragma unroll
    for (int slot = 0; slot < 4; ++slot) {
      dst[slot] = src[slot];
    }
  }
}

// Emits `s_nop 1` followed by v_mmac_f32_16x16x16_f16 (is_half == true) or
// v_mmac_f32_16x16x16_bf16 (is_half == false) via inline assembly, with
// reg_c used as both the destination (%0) and the accumulator input (tied
// through the "0" constraint).
// NOTE(review): the s_nop is presumably a hazard wait required before the
// MMAC on this hardware — confirm against the ISA documentation.
template<bool is_half>
inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
  if constexpr (is_half){
    asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" :
                 "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
  }
  else{
    asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" :
                 "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
  }
}

136
template<bool is_half>
137
138
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
139
    if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);}
140
    else{
141
      reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
142
143
144
145
146
147
148
    }
}

// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
149
          bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES>  // Zero means no partitioning.
zhangshao's avatar
zhangshao committed
150
__global__ void paged_attention_kernel_TC(
151
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
152
153
154
    float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads,head_size]
    scalar_t* __restrict__ out_tmp,  // [num_seqs, num_heads, max_num_partitions,head_size]
155
156
157
158
159
160
161
162
163
164
165
166
167
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_heads,
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
168
    const float* k_scale, const float* v_scale, const int tp_rank, 
169
    const int blocksparse_local_blocks, const int blocksparse_vert_stride, 
170
171
    const int blocksparse_block_size, const int blocksparse_head_sliding_step,int PARTITION_SIZE=0) {
#if defined(__gfx936__) || defined(__gfx928__)
172
  const int seq_idx = blockIdx.z;
zhangshao's avatar
zhangshao committed
173
174
  const int partition_idx = blockIdx.x;
  const int max_num_partitions = gridDim.x;
175
  const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
176
177
  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
  const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
178
  if (partition_idx * PARTITION_SIZE >= seq_len) return;
179
180
181
182
  constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
  static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
  const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
183
184
185
186
187
188
189
190
  const int start_block_idx = partition_idx * num_blocks_per_partition;
  const int end_block_idx =MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
  const int num_blocks = end_block_idx - start_block_idx;
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
  const int num_tokens = end_token_idx - start_token_idx;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  constexpr int x = 16 / sizeof(cache_t);
191
192
193
194
195
196
197
198
  const int thread_idx = threadIdx.x;
  const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
  const int lane = thread_idx % WARP_SIZE;
  const int rowid = lane%16;
  const int rows = lane/16;
  
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
zhangshao's avatar
zhangshao committed
199
  const int head_idx=(blockIdx.y / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.y % num_blocks_per_kv) * REUSE_KV_TIMES;
200
201

  int q_boundary=REUSE_KV_TIMES;
zhangshao's avatar
zhangshao committed
202
  if(num_heads < REUSE_KV_TIMES*gridDim.y && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv)
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
    q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
  float alibi_slope[reuse_group]={0.f};
  if(alibi_slopes != nullptr){
    for(int i=0;i<reuse_group;i++){
      int reuse_kv_idx=rows+i*4;
      if(reuse_kv_idx<q_boundary) alibi_slope[i]=alibi_slopes[head_idx+reuse_kv_idx];
    }
  }
  float qk_max[reuse_group];
  for(int i=0;i<reuse_group;i++){
    qk_max[i]=-FLT_MAX;
  }

  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  
  half4x2 q_vec;
  q_vec.data[0]={0,0,0,0};
  q_vec.data[1]={0,0,0,0};
  
  __shared__ half4x2 q_vecs[REUSE_KV_TIMES][16];
225
    for(int i=0;i<q_boundary;i++){
226
    if(thread_idx<16){
227
      half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
zhangshao's avatar
zhangshao committed
228
229
230
231
232
233
      if constexpr(is_half){
        scalar_t *t=reinterpret_cast<scalar_t*>(&temp);
        #pragma unroll
        for(int k=0;k<8;k++){
          from_float(t[k],to_float(t[k])*scale);
        }
234
235
      }
      q_vecs[i][thread_idx]=temp;
236
237
238
239
240
    }
  }
  __syncthreads();
  extern __shared__ char shared_mem[];
  scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem);
241
242
243
  // __shared__ float red_smem[2 * NUM_WARPS];
  __shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
  __shared__ float s_logit[NUM_WARPS];
244
245
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8;
246
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
247
248
249
250
251
252
253
254
255
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
    const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
    float4_t qk_vec={0,0,0,0};
    half4x2 k_vec[2];
    k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
    #pragma unroll
    for(int i=0;i<3;i++){
      if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
      k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*512);
256
257
      builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
      builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
258
259
260
261
    }
    //tail
    {
      if(rowid<q_boundary)q_vec=q_vecs[rowid][3*4+rows];
262
      builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec);
263
264
265
266
267
268
269
      v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
    }
    #pragma unroll
    for(int i=0;i<reuse_group;i++){
      int reuse_kv_idx=rows+i*4;
      if(reuse_kv_idx<REUSE_KV_TIMES){
        if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
zhangshao's avatar
zhangshao committed
270
271
272
        else {
          if constexpr(!is_half) qk_vec[i]*=scale;
        }
273
274
275
        const int token_idx = block_idx * BLOCK_SIZE+rowid;
        if(alibi_slope[i] != 0){
          float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
zhuwenwen's avatar
zhuwenwen committed
276
          qk_vec[i] += alibi;
277
        }
278
        const bool mask = (token_idx >= seq_len);
279
280
281
282
283
284
285
286
287
288
        if(mask){
          from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
        }
        else{
          from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
          qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
        }
      }
    }
  }
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
  // compute max
  #pragma unroll
  for (int mask = 8; mask >= 1; mask /= 2) {
    #pragma unroll
    for(int r=0;r<reuse_group;r++){
      qk_max[r]=fmaxf(qk_max[r],__shfl_xor(qk_max[r],mask));
    }
  }
  #pragma unroll
  for(int r=0;r<reuse_group;r++){
    if(rowid==0&&r*4+rows<q_boundary){
      s_max[r*4+rows][warp_idx] = qk_max[r];
    }
  }
  __syncthreads();
  __shared__ float max_out[REUSE_KV_TIMES];
  __shared__ float expsum_out[REUSE_KV_TIMES];
306
307
  for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
    const int head_idx_ = head_idx + reuse_kv_idx;
308
    float qk_max_tmp = lane < NUM_WARPS ? s_max[reuse_kv_idx][lane] : -FLT_MAX;
309
310
    float exp_sum = 0.f;
    #pragma unroll
311
312
    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      qk_max_tmp = fmaxf(qk_max_tmp, __shfl_xor(qk_max_tmp, mask));
313
    }
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
    qk_max_tmp = __shfl(qk_max_tmp, 0);
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
      from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
      exp_sum += val;
    }
    exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
    // Compute softmax.
    const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
    }
    if(USE_PARTITIONING&&thread_idx == 0){
      max_out[reuse_kv_idx] = qk_max_tmp;
      expsum_out[reuse_kv_idx]=exp_sum;
329
330
    }
  }
331
  __syncthreads();
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
  constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
  if constexpr(REUSE_KV_TIMES<=2){
    float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
    #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      #pragma unroll
      for(int k=0;k<REUSE_KV_TIMES;k++)
      {
        accs[k][i] = 0.f;
      }
    }
    scalar_t zero_value;
    zero(zero_value);
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
      const int64_t physical_block_number =
          static_cast<int64_t>(block_table[block_idx]);
      const int token_idx = block_idx * BLOCK_SIZE +rows*4;
      half4_t logits_vec={0,0,0,0};
      if(rowid<4*q_boundary){
          logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
      }
      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                            kv_head_idx * kv_head_stride + rows*4+rowid*16;
      #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        #pragma unroll
        for(int k=0;k<4;k++){
          int offset=i*1024+k*256;
          half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
          if (block_idx == num_seq_blocks - 1) {
            scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
            #pragma unroll
            for (int j = 0; j < 4; j++) {
              v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
            }
          }
          float4_t out_vec={0,0,0,0};
370
          builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
371
372
373
374
375
376
377
378
379
          if(rows==k){
            for(int resuseid=0;resuseid<REUSE_KV_TIMES;resuseid++){
              accs[resuseid][i]+=out_vec[resuseid];
            }
          }
        }
      } 
    } 
    __syncthreads();
zhangshao's avatar
zhangshao committed
380
    using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
381
382
    // Perform reduction across warps.
    for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
zhangshao's avatar
zhangshao committed
383
384
385
386
387
388
389
390
      if constexpr (NUM_THREADS>64){
        floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
        #pragma unroll
        for (int i = NUM_WARPS; i > 1; i /= 2) {
          int mid = i / 2;
          // Upper warps write to shared memory.
          if (warp_idx >= mid && warp_idx < i) {
            out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
391
          }
zhangshao's avatar
zhangshao committed
392
393
394
395
396
397
398
399
          __syncthreads();
          // Lower warps update the output.
          if (warp_idx < mid) {
            floatV_t tmp=out_smem[thread_idx];
            #pragma unroll
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
              accs[reuse_kv_idx][i] += tmp[i];
            }
400
          }
zhangshao's avatar
zhangshao committed
401
          __syncthreads();
402
403
404
        }
      }
      if (warp_idx == 0) {
405
406
407
408
409
410
411
412
413
414
        scalar_t* out_ptr;
        int out_offset;
        if(USE_PARTITIONING){
          out_offset=max_num_partitions*HEAD_SIZE;
          out_ptr=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
        }
        else{
          out_ptr=out + seq_idx * num_heads  * HEAD_SIZE + head_idx*HEAD_SIZE;
        } 
        #pragma unroll
415
416
417
418
419
420
421
        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
          const int row_idx = lane + i * WARP_SIZE;
            from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
        }
      }
    }
  }
zhangshao's avatar
zhangshao committed
422
#if defined __gfx928__
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
  else{
    constexpr int GROUPS=reuse_group*4;
    // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
    float accs[GROUPS][NUM_ROWS_PER_THREAD];
    #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      #pragma unroll
      for(int k=0;k<GROUPS;k++)
      {
        accs[k][i] = 0.f;
      }
    }
    scalar_t zero_value;
    zero(zero_value);
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
      const int64_t physical_block_number =
          static_cast<int64_t>(block_table[block_idx]);
      const int token_idx = block_idx * BLOCK_SIZE +rows*4;
      half4_t logits_vec={0,0,0,0};
      if(rowid<q_boundary){
          logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
      }
      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                            kv_head_idx * kv_head_stride + rows*4+rowid*16;
      #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        #pragma unroll
        for(int k=0;k<4;k++){
          int offset=i*1024+k*256;
          half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
          if (block_idx == num_seq_blocks - 1) {
            scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
            #pragma unroll
            for (int j = 0; j < 4; j++) {
              v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
            }
          }
          float4_t out_vec={0,0,0,0};
462
          builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
463
464
465
466
467
          for(int g=0;g<reuse_group;g++){
            accs[g*4+k][i]+=out_vec[g];
          }
        }
      } 
zhangshao's avatar
zhangshao committed
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
    }
    if constexpr (NUM_THREADS>64){
      __syncthreads();
      using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
      // Perform reduction across warps.
      
      for(int reuse_kv_idx=0; reuse_kv_idx<GROUPS; reuse_kv_idx++) {
        
        floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
        #pragma unroll
        for (int i = NUM_WARPS; i > 1; i /= 2) {
          int mid = i / 2;
          // Upper warps write to shared memory.
          if (warp_idx >= mid && warp_idx < i) {
            out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
483
          }
zhangshao's avatar
zhangshao committed
484
485
486
487
488
489
490
491
          __syncthreads();
          // Lower warps update the output.
          if (warp_idx < mid) {
            floatV_t tmp=out_smem[thread_idx];
            #pragma unroll
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
              accs[reuse_kv_idx][i] += tmp[i];
            }
492
          }
zhangshao's avatar
zhangshao committed
493
          __syncthreads();
494
495
496
497
        }
      }
    }
    if (warp_idx == 0) {
498
499
500
501
502
503
504
505
506
507
      scalar_t* out_ptr_base;
      int out_offset;
      if(USE_PARTITIONING){
        out_offset=max_num_partitions*HEAD_SIZE;
        out_ptr_base=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
      }
      else{
        out_offset=HEAD_SIZE;
        out_ptr_base=out + seq_idx * num_heads  * HEAD_SIZE + head_idx*HEAD_SIZE;
      } 
508
509
510
      for(int g=0;g<reuse_group;g++){
        int reusekvid=g*4+rows;
        if(reusekvid<q_boundary){
511
          scalar_t* out_ptr = out_ptr_base + reusekvid * out_offset;
512
513
514
515
516
517
518
519
520
521
522
          #pragma unroll
          for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
            for(int k=0;k<4;k++){
              const int row_idx = rowid+16*k + i * WARP_SIZE;
              from_float(*(out_ptr + row_idx), accs[g*4+k][i]);
            }
          }
        }
      }
    }
  }
zhangshao's avatar
zhangshao committed
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
#else
  else{
    constexpr int GROUPS=reuse_group*4;
    // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
    float4_t accs[4][NUM_ROWS_PER_THREAD];
    #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      #pragma unroll
      for(int k=0;k<4;k++)
      {
        accs[k][i] = {0.f,0.f,0.f,0.f};
      }
    }
    scalar_t zero_value;
    zero(zero_value);
    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
      const int64_t physical_block_number =
          static_cast<int64_t>(block_table[block_idx]);
      const int token_idx = block_idx * BLOCK_SIZE +rows*4;
      half4_t logits_vec={0,0,0,0};
      if(rowid<q_boundary){
          logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
      }
      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                            kv_head_idx * kv_head_stride + rows*4+rowid*16;
      #pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        #pragma unroll
        for(int k=0;k<4;k++){
          int offset=i*1024+k*256;
          half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
          if (block_idx == num_seq_blocks - 1) {
            scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
            #pragma unroll
            for (int j = 0; j < 4; j++) {
              v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
            }
          }
562
          builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,accs[k][i]);
zhangshao's avatar
zhangshao committed
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
        }
      } 
    }
    if constexpr (NUM_THREADS>64){
      __syncthreads();
      using floatV_t = __attribute__( (__vector_size__(reuse_group * sizeof(float)) )) float;
      // Perform reduction across warps.
      
      for(int m=0; m<4; m++) {
        floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
        #pragma unroll
        for (int i = NUM_WARPS; i > 1; i /= 2) {
          int mid = i / 2;
          // Upper warps write to shared memory.
          if (warp_idx >= mid && warp_idx < i) {
            for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
              out_smem[((warp_idx - mid) * 64+lane)*NUM_ROWS_PER_THREAD+k]=*(floatV_t*)(&(accs[m][k]));
            }
          }
          __syncthreads();
          // Lower warps update the output.
          if (warp_idx < mid) {
            for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
              floatV_t tmp=out_smem[thread_idx*NUM_ROWS_PER_THREAD+k];
              #pragma unroll
              for (int i = 0; i < reuse_group; i++) {
                accs[m][k][i] += tmp[i];
              }
            }
          }
          __syncthreads();
        }
      }
    }
    if (warp_idx == 0) {
598
599
600
601
602
603
604
605
606
607
      scalar_t* out_ptr_base;
      int out_offset;
      if(USE_PARTITIONING){
        out_offset=max_num_partitions*HEAD_SIZE;
        out_ptr_base=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
      }
      else{
        out_offset=HEAD_SIZE;
        out_ptr_base=out + seq_idx * num_heads  * HEAD_SIZE + head_idx*HEAD_SIZE;
      } 
zhangshao's avatar
zhangshao committed
608
609
610
      for(int g=0;g<reuse_group;g++){
        int reusekvid=g*4+rows;
        if(reusekvid<q_boundary){
611
          scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset;
zhangshao's avatar
zhangshao committed
612
613
614
615
616
617
618
619
620
621
622
623
          #pragma unroll
          for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
            for(int k=0;k<4;k++){
              const int row_idx = rowid+16*k + i * WARP_SIZE;
              from_float(*(out_ptr + row_idx), accs[k][i][g]);
            }
          }
        }
      }
    }
  }
#endif
624
625
626
627
628
629
  if (USE_PARTITIONING&&thread_idx < q_boundary){
    int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx;
    *(max_logits+offset)=max_out[thread_idx];
    *(exp_sums+offset)=expsum_out[thread_idx];
  }
#endif
630
631
632

}

633

634
// Grid: (num_heads, num_seqs).
635
636
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS>
__global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kernel_opt_tc(
637
638
639
640
641
642
643
644
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ seq_lens,      // [num_seqs]
645
    const int max_num_partitions,int PARTITION_SIZE=512) {
646
647
648
649
650
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int seq_len = seq_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
651
  if(num_partitions==1)return;
652
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
653
654
655
  const int thread_idx = threadIdx.x;
  const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
  const int lane = thread_idx % WARP_SIZE;
656

657
658
659
  int offset = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
  const float* max_logits_ptr = max_logits + offset;
  const float* exp_sums_ptr = exp_sums + offset;
660
  float max_logit = -FLT_MAX;
661
  float global_max_logit = -FLT_MAX;
662
  float global_exp_sum = 0.0f;
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
  if constexpr(NUM_THREADS == 64&& HEAD_SIZE==128){
    __shared__ float shared_exp_sums[64];
    if(thread_idx<num_partitions){
      max_logit = max_logits_ptr[thread_idx];
      global_exp_sum = exp_sums_ptr[thread_idx];
      global_max_logit = max_logit;
    }
    #pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      global_max_logit = fmaxf(global_max_logit, VLLM_SHFL_XOR_SYNC(global_max_logit, mask));
    }
    if(thread_idx<num_partitions){
      global_exp_sum = global_exp_sum * __expf(max_logit - global_max_logit);
      shared_exp_sums[thread_idx] = global_exp_sum;
    }
    #pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      global_exp_sum += VLLM_SHFL_XOR_SYNC(global_exp_sum, mask);
    }
    const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
    
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + offset * HEAD_SIZE;
    using half2_t = vec2data<scalar_t>;
    float2_t acc = {0.0f, 0.0f};
    half2_t acc_half;
    for (int j = 0; j < num_partitions; ++j) {
      half2_t tout= *(half2_t*)(tmp_out_ptr + j * HEAD_SIZE + thread_idx*2);
      float temp_sum=shared_exp_sums[j]*inv_global_exp_sum;
      #pragma unroll
      for(int i=0;i<2;i++){
        acc[i] += to_float(tout.data[i])*temp_sum;
      }
    }
    #pragma unroll
    for(int i=0;i<2;i++){
      from_float(acc_half.data[i],acc[i]);
    }
    *(half2_t*)(out_ptr+thread_idx*2)=acc_half;
702
  }
703
704
705
706
707
708
709
710
711
712
713
714
715
  else{
    // Size: 2 * num_partitions.
    extern __shared__ char shared_mem[];
    // Workspace for reduction.
    __shared__ float red_smem[2 * NUM_WARPS];
    // Load max logits to shared memory.
    float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
    for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
      const float l = max_logits_ptr[i];
      shared_max_logits[i] = l;
      max_logit = fmaxf(max_logit, l);
    }
    __syncthreads();
716

717
718
    // Get the global max logit.
    // Reduce within the warp.
719
  #pragma unroll
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
    }
    if (lane == 0) {
      red_smem[warp_idx] = max_logit;
    }
    __syncthreads();
    // Reduce across warps.
    max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
    #pragma unroll
    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
    }
    // Broadcast the max value to all threads.
    max_logit = VLLM_SHFL_SYNC(max_logit, 0);

    // Load rescaled exp sums to shared memory.
    float* shared_exp_sums =
        reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
    for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
      float l = shared_max_logits[i];
      float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
      global_exp_sum += rescaled_exp_sum;
      shared_exp_sums[i] = rescaled_exp_sum;
    }
    __syncthreads();
    global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
    const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
    // Aggregate tmp_out to out.
    const scalar_t* tmp_out_ptr =
        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE;
    scalar_t* out_ptr =
        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    #pragma unroll
    for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
      float acc = 0.0f;
      for (int j = 0; j < num_partitions; ++j) {
        acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
              inv_global_exp_sum;
      }
      from_float(out_ptr[i], acc);
762
763
764
765
766
767
768
    }
  }
}

}  // namespace vllm


769
770
// Launches the tensor-core paged-attention kernel and, when the sequences were
// split into more than one partition, the reduce kernel that merges the
// partial results into `out`.
//
// Expansion-site requirements: T, CACHE_T, BLOCK_SIZE, NUM_THREADS, KV_DTYPE,
// IS_BLOCK_SPARSE, REUSE_KV_TIMES, grid, block, shared_mem_size, stream,
// reduce_grid, reduce_shared_mem_size, max_num_partitions, PARTITION_SIZE and
// every kernel argument named below must be in scope.
//
// Reduce launch: 64 threads with no dynamic shared memory for 2..64
// partitions, 128 threads with reduce_shared_mem_size bytes for more; nothing
// at all for a single partition.
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE)                                 \
  hipLaunchKernelGGL(                                                         \
      (vllm::paged_attention_kernel_TC<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,     \
                                       NUM_THREADS, KV_DTYPE,                 \
                                       IS_BLOCK_SPARSE, REUSE_KV_TIMES>),     \
      dim3(grid), dim3(block), shared_mem_size, stream,                       \
      exp_sums_ptr, max_logits_ptr, out_ptr, tmp_out_ptr,                     \
      query_ptr, key_cache_ptr, value_cache_ptr,                              \
      num_heads, num_kv_heads, scale,                                         \
      block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,                 \
      alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,            \
      k_scale_ptr, v_scale_ptr, tp_rank,                                      \
      blocksparse_local_blocks, blocksparse_vert_stride,                      \
      blocksparse_block_size, blocksparse_head_sliding_step,                  \
      PARTITION_SIZE);                                                        \
  if (max_num_partitions > 1) {                                               \
    if (max_num_partitions <= 64) {                                           \
      hipLaunchKernelGGL(                                                     \
          (vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>),  \
          dim3(reduce_grid), dim3(64), 0, stream,                             \
          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr,   \
          max_num_partitions, PARTITION_SIZE);                                \
    } else {                                                                  \
      hipLaunchKernelGGL(                                                     \
          (vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 128>), \
          dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream,       \
          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr,   \
          max_num_partitions, PARTITION_SIZE);                                \
    }                                                                         \
  }


void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions,
      int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks)
{
798
  reusekv=1;
zhangshao's avatar
zhangshao committed
799
  num_thread=256;
800
801
802
803
804
805
806
807
808
  PARTITION_SIZE=512;
  max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
  if(max_seq_len==8192&&num_blocks==1024){//ali test
    if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
    if(batchsize==1&&qheads==32&&kvheads==32){num_thread=64;return;}
    if(batchsize==1){
      if(qheads==52){reusekv=8;return;}
      if(qheads==13){reusekv=2;return;}
      reusekv=4;return;
zhangshao's avatar
zhangshao committed
809
    }
810
811
812
813
814
815
    if(batchsize==64){
      if(qheads==13){PARTITION_SIZE=256;num_thread=128;reusekv=8;}
      else if(qheads==32){PARTITION_SIZE=1024;reusekv=8;}
      else if(qheads==52||qheads==26){reusekv=16;}
      else reusekv=8;
      max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
zhangshao's avatar
zhangshao committed
816
817
818
      return;
    }
  }
819
  if(qheads==kvheads){
820
821
822
823
824
825
826
    if(max_seq_len<=8192){
      if(batchsize*qheads>=512){
        max_num_partitions=1;
        num_thread=64;
      }
      if(qheads==32&&max_seq_len<=1024)max_num_partitions=1;
    }
827
828
    return;
  }
829
  if(max_seq_len<800)max_num_partitions=1;
830
  if(qheads>kvheads*4){
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
    if(max_seq_len<=1000||
        max_seq_len<1500&&(batchsize>=8&&qheads>=8||batchsize>=64)||
        max_seq_len<1900&&batchsize>=8&&qheads==28
        )
        max_num_partitions=1;
    int blocks=max_num_partitions*batchsize*qheads;
    if(device_name=="gfx928"){
      if(batchsize*qheads>1024&&max_seq_len>=2000){
        max_num_partitions=1;
        if(max_seq_len<3900)reusekv=8;
        else if(max_seq_len<7800)reusekv=4;
        else{
          PARTITION_SIZE=2048;
          reusekv=8;
          max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
846
        }
847
        return;
848
849
      }
    }
850
851
852
853
854
855
856
857
858
859
860
861
    if(max_num_partitions==1){
      if(max_seq_len<512){
        int bytes=max_seq_len*qheads*batchsize;
        if(bytes<51200)reusekv=1;
        else if(bytes<256000)reusekv=4;
        else reusekv=8;
        return;
      }
      if(batchsize<4||batchsize==4&&qheads==8)reusekv=1;
      else if(batchsize<32||batchsize<=64&&qheads==8)reusekv=4;
      else reusekv=8;
      return;
862
    }
863
864
865
    if(blocks<150)return;
    if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
    reusekv=8;return;
866
  }
867
868
869
870
871
872
873
874
  if(device_name=="gfx928"){
    if(batchsize*qheads>1024&&max_seq_len>=2000){
      max_num_partitions=1;
      if(max_seq_len<7800)reusekv=4;
      else{
        PARTITION_SIZE=2048;
        reusekv=4;
        max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
zhangshao's avatar
zhangshao committed
875
      }
876
      return;
zhangshao's avatar
zhangshao committed
877
    }
878
  }
879
880
881
882
883
  if(max_seq_len<=1000||
      max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
        max_num_partitions=1;
  int blocks=max_num_partitions*batchsize*qheads;
  if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||max_seq_len>=2000))reusekv=4;
884

885
}
886
template <typename T, typename CACHE_T, int BLOCK_SIZE,
887
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
zhuwenwen's avatar
zhuwenwen committed
888
void paged_attention_v2_launcher_opt_tc(
889
890
891
892
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
893
894
    const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
895
    const int blocksparse_vert_stride, const int blocksparse_block_size,
896
    const int blocksparse_head_sliding_step) {
897
898
899
900
901
902
903
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);
zhangshao's avatar
zhangshao committed
904
  int num_blocks=key_cache.size(0);
905
906
907
908
909
910
911

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;
  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
912
913
  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
914
915
916
  // float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  // float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  // T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
917
918
919
920
921
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
922
923
924
925
926
927
928
929
  static float* exp_sums_ptr = nullptr;
  static float* max_logits_ptr = nullptr;
  static T* tmp_out_ptr = nullptr;
  if(exp_sums_ptr == nullptr){
      hipMalloc(&exp_sums_ptr, 1000000); // 1m
      hipMalloc(&max_logits_ptr, 1000000); // 1m
      hipMalloc(&tmp_out_ptr, 100000000); // 100m
  }
930
931
932
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 reduce_grid(num_heads, num_seqs);
933
  
934
  if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
zhangshao's avatar
zhangshao committed
935
    constexpr int HEAD_SIZE=128;
936
937
938
939
940
941
    int reusekv, num_thread,max_num_partitions,PARTITION_SIZE;
    get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
    if(PA_PARTITION_SIZE!=0){
      PARTITION_SIZE=PA_PARTITION_SIZE;
      max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
    }
zhangshao's avatar
zhangshao committed
942
943
    if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
    if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
944
945
946
947
    if(PA_USE_V1!=0)max_num_partitions=1;
    if(max_num_partitions==1)PARTITION_SIZE=max_seq_len;
    assert(num_seqs*num_heads*max_num_partitions*head_size<=100000000);
    int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
zhangshao's avatar
zhangshao committed
948
949
950
951
    REUSEKV_SWITCH(reusekv,[&] {
      NUM_THREADS_SWITCH(num_thread , [&] {
        constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
        int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
952
        if(max_num_partitions==1)PARTITION_SIZE=0;
zhangshao's avatar
zhangshao committed
953
954
955
956
957
958
959
960
961
962
        int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
        dim3 grid;
        grid.y = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
        grid.x = max_num_partitions;
        grid.z = num_seqs;
        dim3 block(NUM_THREADS);
        int shared_mem_size = ::max(logits_size, outputs_size);
        if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
                                  reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
        LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE);
963
      });
zhangshao's avatar
zhangshao committed
964
965
    });
  }
966
967
968
}

// Invokes the opt_tc launcher for one compile-time combination of scalar
// type, cache type, block size, KV-cache dtype and sparsity, forwarding the
// enclosing function's arguments by name.
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)      \
  paged_attention_v2_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,           \
                                     IS_BLOCK_SPARSE>(                           \
      out, exp_sums, max_logits, tmp_out,                                        \
      query, key_cache, value_cache,                                             \
      num_kv_heads, scale,                                                       \
      block_tables, seq_lens, max_seq_len,                                       \
      alibi_slopes, k_scale, v_scale, tp_rank,                                   \
      blocksparse_local_blocks, blocksparse_vert_stride,                         \
      blocksparse_block_size, blocksparse_head_sliding_step);
976
977

// Turns the enclosing function's runtime `is_block_sparse` flag into the
// IS_BLOCK_SPARSE template argument of CALL_V2_LAUNCHER. The fourth argument
// is the KV-cache dtype enum value.
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
  if (!is_block_sparse) {                                          \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, false);     \
  } else {                                                         \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, true);      \
  }

// Dispatches on the enclosing function's runtime `block_size`.
// paged_attention_v2_launcher_opt_tc only does work for BLOCK_SIZE == 16
// (its whole body sits under `if constexpr (BLOCK_SIZE == 16 && ...)`), so 16
// is the only size instantiated here. Instantiating 8 or 32 would only add
// compile time and a path that returns without writing `out`; any other
// size is rejected with an explicit error.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
  switch (block_size) {                                           \
    case 16:                                                      \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

1002
void paged_attention_v2(
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
1019
    const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, 
1020
1021
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
1022
    const int64_t blocksparse_head_sliding_step);
1023

zhuwenwen's avatar
zhuwenwen committed
1024
void paged_attention_v2_opt_tc(
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
1041
    const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
1042
1043
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
1044
    const int64_t blocksparse_head_sliding_step) {
1045
1046
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
  if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
zhangshao's avatar
zhangshao committed
1047
      block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
1048
    paged_attention_v2(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
1049
1050
                       scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
                       k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
1051
                       blocksparse_block_size,blocksparse_head_sliding_step);
1052
1053
1054
1055
1056
1057
1058
  }
  else{
    DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
  }
}

1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
// Generic paged-attention V1 entry point, declared here and defined elsewhere.
// paged_attention_v1_opt_tc falls back to it for configurations the
// tensor-core kernels do not handle.
void paged_attention_v1(
    torch::Tensor& out,          // [num_seqs, num_heads, head_size]
    torch::Tensor& query,        // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size,
    int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    torch::Tensor& k_scale,
    torch::Tensor& v_scale,
    const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride,
    const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// Paged attention V1 entry point that prefers the tensor-core kernels.
//
// Configurations outside the opt_tc envelope go to paged_attention_v1. The
// envelope is: kv_cache_dtype "auto", non-float32 query, no block sparsity,
// block_size 16, head_size 128, and a gfx928 / gfx936 device.
//
// Eligible calls are served by paged_attention_v2_opt_tc. `out` is passed in
// its exp_sums / max_logits / tmp_out slots because the opt_tc launcher uses
// its own scratch buffers rather than those tensors.
void paged_attention_v1_opt_tc(
    torch::Tensor& out,          // [num_seqs, num_heads, head_size]
    torch::Tensor& query,        // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
  const bool on_supported_arch =
      device_name == "gfx928" || device_name == "gfx936";
  const bool use_opt_tc = kv_cache_dtype == "auto" &&
                          query.scalar_type() != at::ScalarType::Float &&
                          !is_block_sparse && block_size == 16 &&
                          query.size(2) == 128 && on_supported_arch;

  if (!use_opt_tc) {
    paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads, scale,
                       block_tables, seq_lens, block_size, max_seq_len,
                       alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
                       blocksparse_local_blocks, blocksparse_vert_stride,
                       blocksparse_block_size, blocksparse_head_sliding_step);
    return;
  }

  paged_attention_v2_opt_tc(out, out, out, out, query, key_cache, value_cache,
                            num_kv_heads, scale, block_tables, seq_lens,
                            block_size, max_seq_len, alibi_slopes,
                            kv_cache_dtype, k_scale, v_scale, tp_rank,
                            blocksparse_local_blocks, blocksparse_vert_stride,
                            blocksparse_block_size,
                            blocksparse_head_sliding_step);
}

1110
1111
1112
// Drop this file's helper macros so they do not leak past this point.
#undef DIVIDE_ROUND_UP
#undef MAX
#undef MIN
#undef WARP_SIZE