/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>

#include "attention_dtypes.h"
#include "attention_utils.cuh"

// Platform-specific FP8 quantization helpers. On ROCm the HIP bfloat16 type is
// aliased to the CUDA name so the rest of the file can use __nv_bfloat16.
#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
  #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif

// Number of lanes in a warp (wavefront on ROCm).
#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

// NOTE: these are function-like macros; each argument may be evaluated more
// than once, so do not pass expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Integer ceiling division: smallest k such that k * b >= a (for positive b).
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

namespace vllm {

// Utility function for attention softmax.
//
// Sums `sum` over the threads of the calling thread block and returns the
// value held by lane 0 of each warp after the second reduction stage.
//
//   red_smem : shared-memory scratch; slot `threadIdx.x / WARP_SIZE` is
//              written and slots [0, NUM_WARPS) are read, so it must hold at
//              least NUM_WARPS floats.
//   sum      : this thread's partial value.
//
// Every thread of the block must call this function because it contains a
// __syncthreads(). The second stage is an XOR butterfly over the first
// NUM_WARPS lanes, so it presumably expects NUM_WARPS to be a power of two
// no larger than WARP_SIZE and the block to contain exactly NUM_WARPS warps
// -- TODO confirm against the launchers.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return VLLM_SHFL_SYNC(sum, 0);
}

// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
//
// Core paged-attention routine shared by the v1 and v2 kernels. One thread
// block handles the (head, sequence, partition) triple given by
// blockIdx.{x, y, z} and proceeds in three phases:
//   1. QK:      each warp walks the KV blocks of its partition, computes the
//               scaled q.k logits (plus optional ALiBi bias) into `logits`,
//               and tracks the running maximum.
//   2. Softmax: the block-wide max and exp-sum are reduced, `logits` is
//               normalised in place, and (when partitioning) the max and
//               exp-sum are written to `max_logits` / `exp_sums`.
//   3. PV:      each warp accumulates logits * V in FP32, the per-warp
//               accumulators are tree-reduced through shared memory, and
//               warp 0 writes the result to `out`.
//
// Shared memory:
//   * static : q_vecs (the query) and red_smem (2 * NUM_WARPS floats).
//   * dynamic: `shared_mem`, first used as the FP32 logits buffer (indexed by
//              token offset within the partition, including the padding of
//              the last block) and later reused as the cross-warp output
//              buffer of (NUM_WARPS / 2) * HEAD_SIZE floats. The launcher
//              must size it for both uses.
//
// PARTITION_SIZE == 0 disables partitioning; `exp_sums` and `max_logits` are
// then never dereferenced.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                 // head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // scalar: number of KV heads
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int seq_len = seq_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
    // No work to do. Terminate the thread block.
    // The condition is uniform across the block, so no thread is left waiting
    // at a later barrier.
    return;
  }

  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
  const int num_blocks_per_partition =
      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx =
      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx =
      MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx =
      MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
  // Both operands are compile-time constants, so check at compile time
  // instead of with a per-thread runtime assert.
  static_assert(NUM_THREADS % THREAD_GROUP_SIZE == 0,
                "THREAD_GROUP_SIZE must divide NUM_THREADS");
  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope =
      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread
  // group fetch or compute 16 bytes at a time. For example, if the size of a
  // thread group is 4 and the data type is half, then the vector size is 16 /
  // (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in
  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
  // q is split from a qkv tensor, it may not be contiguous.
  // The sequence offset is computed in 64 bits so that seq_idx * q_stride
  // cannot wrap around for very large batches.
  const scalar_t* q_ptr =
      q + static_cast<int64_t>(seq_idx) * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with a
                    // memory wall right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction. The first NUM_WARPS slots hold the per-warp
  // maxima; the second NUM_WARPS slots are handed to block_sum.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(cache_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;

  // blocksparse specific vars
  int bs_block_offset;
  int q_bs_block_id;
  if constexpr (IS_BLOCK_SPARSE) {
    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
    // blocksparse_block_size);
    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
    if (blocksparse_head_sliding_step >= 0)
      // sliding on q heads
      bs_block_offset =
          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
    else
      // sliding on kv heads
      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
                            (-blocksparse_head_sliding_step) +
                        1;
  }

  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    // For blocksparse attention: skip computation on blocks that are not
    // attended
    if constexpr (IS_BLOCK_SPARSE) {
      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
      const bool is_remote =
          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
      const bool is_local =
          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
      if (!is_remote && !is_local) {
        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
          const int physical_block_offset =
              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;

          if (thread_group_offset == 0) {
            // NOTE(linxihui): assign very large number to skipped tokens to
            // avoid contribution to the sumexp softmax normalizer. This will
            // not be used at computing sum(softmax*v) as the blocks will be
            // skipped.
            logits[token_idx - start_token_idx] = -FLT_MAX;
          }
        }
        continue;
      }
    }
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in
    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
    // has 1, 5, 9, ... th vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset =
          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr =
            k_cache + physical_block_number * kv_block_stride +
            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        } else {
          // Vector conversion from Quant_vec to K_vec.
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
              k_vec_quant, kv_scale);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
                             q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE(woosuk): It is required to zero out the masked logits.
        const bool mask = token_idx >= seq_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = VLLM_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  // The small epsilon keeps the division finite when exp_sum is zero.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits +
                            seq_idx * num_heads * max_num_partitions +
                            head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
                          head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD =
      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    // For blocksparse attention: skip computation on blocks that are not
    // attended
    if constexpr (IS_BLOCK_SPARSE) {
      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
        continue;
      }
    }
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
                                                           start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                           kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        } else {
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
                                                                    kv_scale);
        }
        if (block_idx == num_seq_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the
          // context, we should explicitly zero out the values since they may
          // contain NaNs. See
          // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE(woosuk): A barrier is required because the shared memory space for
  // logits is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  // Each round halves the number of participating warps: the upper half
  // publishes its accumulators and the lower half adds them in.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int j = 0; j < NUM_ROWS_PER_THREAD; j++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + j * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[j];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int j = 0; j < NUM_ROWS_PER_THREAD; j++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + j * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[j] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    // The leading term is widened to 64 bits so the product over
    // (seq, head, partition, head_size) cannot overflow int32.
    scalar_t* out_ptr =
        out +
        static_cast<int64_t>(seq_idx) * num_heads * max_num_partitions *
            HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}

// Grid: (num_heads, num_seqs, 1).
//
// Single-pass (non-partitioned) entry point. Forwards every argument to
// paged_attention_kernel with PARTITION_SIZE left at its default of 0 and
// passes nullptr for exp_sums / max_logits, which the non-partitioned path
// never dereferences.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // scalar: number of KV heads
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
      v_cache, num_kv_heads, scale, block_tables, seq_lens,
      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
      kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
      blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
//
// Partitioned entry point. Forwards every argument to paged_attention_kernel
// with the given PARTITION_SIZE; each thread block writes its partial output
// to tmp_out and its softmax statistics to exp_sums / max_logits, which
// paged_attention_v2_reduce_kernel later combines.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,       // [num_seqs, num_heads,
                                          // max_num_partitions]
    scalar_t* __restrict__ tmp_out,       // [num_seqs, num_heads,
                                          // max_num_partitions, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // scalar: number of KV heads
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
      kv_block_stride, kv_head_stride, kv_scale, tp_rank,
      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs).
//
// Combines the per-partition results written by paged_attention_v2_kernel
// into the final attention output for one (head, sequence) pair:
//   1. find the global max logit over the sequence's partitions,
//   2. rescale each partition's exp-sum by exp(partition_max - global_max),
//   3. write out[i] = sum_j tmp_out[j][i] * rescaled_exp_sum[j] / total.
// When the sequence has a single partition, tmp_out is copied through
// unchanged.
//
// Dynamic shared memory: 2 * num_partitions floats (max logits followed by
// rescaled exp sums). Static shared memory: red_smem, 2 * NUM_WARPS floats.
// NUM_THREADS presumably equals blockDim.x -- the loops below mix the two;
// TODO confirm against the launcher.
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int seq_len = seq_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    // Leading terms are widened to 64 bits so the element offsets cannot
    // overflow int32 for very large batches.
    scalar_t* out_ptr = out +
                        static_cast<int64_t>(seq_idx) * num_heads * HEAD_SIZE +
                        head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr =
        tmp_out +
        static_cast<int64_t>(seq_idx) * num_heads * max_num_partitions *
            HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    // The branch is uniform across the block, so no barrier is skipped by
    // only part of the block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits +
                                seq_idx * num_heads * max_num_partitions +
                                head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums =
      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums +
                              seq_idx * num_heads * max_num_partitions +
                              head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  // The small epsilon keeps the division finite when the sum is zero.
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr =
      tmp_out +
      static_cast<int64_t>(seq_idx) * num_heads * max_num_partitions *
          HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr = out +
                      static_cast<int64_t>(seq_idx) * num_heads * HEAD_SIZE +
                      head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
             inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

}  // namespace vllm

// Configures and launches one instantiation of vllm::paged_attention_v1_kernel
// for the given compile-time HEAD_SIZE.
//
// The macro is meant to be expanded inside paged_attention_v1_launcher: it
// picks up that function's template parameters (T, CACHE_T, BLOCK_SIZE,
// NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE) and its locals (grid, block,
// shared_mem_size, stream, the *_ptr variables, strides, ...) by name.
//
// Step 1 passes shared_mem_size to
// VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize for this instantiation
// (by its name, this sets the kernel's max dynamic shared-memory attribute);
// step 2 launches the kernel with the same shared_mem_size on `stream`.
//
// The two statements are wrapped in do { } while (0) so the macro expands to a
// single statement and stays correct under an unbraced if/else. The caller
// supplies the trailing semicolon.
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                  \
  do {                                                                        \
    VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                     \
        ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE,        \
                                                BLOCK_SIZE, NUM_THREADS,      \
                                                KV_DTYPE, IS_BLOCK_SPARSE>),  \
        shared_mem_size);                                                     \
    vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,        \
                                    NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE>   \
        <<<grid, block, shared_mem_size, stream>>>(                           \
            out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
            scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,    \
            alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,      \
            kv_scale, tp_rank, blocksparse_local_blocks,                      \
            blocksparse_vert_stride, blocksparse_block_size,                  \
            blocksparse_head_sliding_step);                                   \
  } while (0)
Woosuk Kwon's avatar
Woosuk Kwon committed
688
689

// TODO(woosuk): Tune NUM_THREADS.
/**
 * Host-side launcher for the paged attention V1 kernel.
 *
 * Extracts sizes, strides and raw device pointers from the tensors, computes
 * the dynamic shared-memory size, and launches
 * vllm::paged_attention_v1_kernel on the current CUDA stream of the device
 * that owns `query`.
 *
 * Launch configuration:
 *   - grid  = (num_heads, num_seqs, 1), taken from query.size(1) / size(0)
 *   - block = NUM_THREADS threads
 *   - dynamic shared memory = max(logits_size, outputs_size) bytes, where
 *     logits_size covers max_seq_len rounded up to a multiple of BLOCK_SIZE
 *     floats and outputs_size covers (NUM_WARPS / 2) * head_size floats.
 *
 * head_size (query.size(2)) selects the kernel instantiation; only 64, 80,
 * 96, 112, 128, 192 and 256 are compiled.
 *
 * @param alibi_slopes  Optional; a null pointer is passed to the kernel when
 *                      absent. The data is reinterpreted as float.
 * @throws c10::Error (via TORCH_CHECK) if head_size is not divisible by the
 *         thread group size or is not one of the compiled head sizes.
 */
template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128>
void paged_attention_v1_launcher(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
    const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // Validate with TORCH_CHECK rather than assert(): assert() is compiled out
  // when NDEBUG is defined, which would silently skip this precondition in
  // release builds.
  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  TORCH_CHECK(head_size % thread_group_size == 0, "head_size (", head_size,
              ") must be divisible by the thread group size (",
              thread_group_size, ")");

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  // Round the sequence length up to a whole number of KV blocks.
  int padded_max_seq_len =
      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_seq_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  // Make the device that owns `query` current for the launch below.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 192:
      LAUNCH_PAGED_ATTENTION_V1(192);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

769
770
771
// Calls paged_attention_v1_launcher for one combination of scalar type,
// KV-cache element type, block size, KV-cache dtype and block-sparsity flag,
// leaving NUM_THREADS at its default. All runtime arguments are picked up by
// name from the scope in which the macro is expanded.
#define CALL_V1_LAUNCHER(SCALAR_T, KV_CACHE_T, BLK_SIZE, KV_KIND, SPARSE)    \
  paged_attention_v1_launcher<SCALAR_T, KV_CACHE_T, BLK_SIZE, KV_KIND,       \
                              SPARSE>(                                       \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
      seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank,                \
      blocksparse_local_blocks, blocksparse_vert_stride,                     \
      blocksparse_block_size, blocksparse_head_sliding_step);

// Turns the runtime flag `is_block_sparse` (read from the enclosing scope)
// into the compile-time IS_BLOCK_SPARSE argument of CALL_V1_LAUNCHER.
//
// Written as if/else rather than a switch over a bool, which compilers flag
// (-Wswitch-bool). The do { } while (0) wrapper makes the expansion a single
// statement; the caller supplies the trailing semicolon.
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
  do {                                                              \
    if (is_block_sparse) {                                          \
      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, true);     \
    } else {                                                        \
      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, false);    \
    }                                                               \
  } while (0)
786
787
788

// Maps the runtime `block_size` (read from the enclosing scope) onto the
// compile-time block size of the V1 launcher; any value other than 8, 16 or
// 32 fails a TORCH_CHECK.
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(SCALAR_T, KV_CACHE_T, KV_KIND)    \
  switch (block_size) {                                               \
    case 8: {                                                         \
      CALL_V1_LAUNCHER_SPARSITY(SCALAR_T, KV_CACHE_T, 8, KV_KIND);    \
      break;                                                          \
    }                                                                 \
    case 16: {                                                        \
      CALL_V1_LAUNCHER_SPARSITY(SCALAR_T, KV_CACHE_T, 16, KV_KIND);   \
      break;                                                          \
    }                                                                 \
    case 32: {                                                        \
      CALL_V1_LAUNCHER_SPARSITY(SCALAR_T, KV_CACHE_T, 32, KV_KIND);   \
      break;                                                          \
    }                                                                 \
    default: {                                                        \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
      break;                                                          \
    }                                                                 \
  }

/**
 * Entry point for paged attention V1.
 *
 * Derives the block-sparsity flag and then dispatches, through
 * DISPATCH_BY_KV_CACHE_DTYPE and CALL_V1_LAUNCHER_BLOCK_SIZE, to the launcher
 * instantiation matching query.dtype(), kv_cache_dtype and block_size. The
 * dispatch macros read the parameters and `is_block_sparse` by name.
 *
 * Tensor layouts:
 *   out          [num_seqs, num_heads, head_size]
 *   query        [num_seqs, num_heads, head_size]
 *   key_cache    [num_blocks, num_heads, head_size/x, block_size, x]
 *   value_cache  [num_blocks, num_heads, head_size, block_size]
 *   block_tables [num_seqs, max_num_blocks_per_seq]
 *   seq_lens     [num_seqs]
 *
 * num_kv_heads is a scalar forwarded to the kernel. alibi_slopes is optional.
 * Block-sparse mode is selected when blocksparse_vert_stride > 1.
 */
void paged_attention_v1(
    torch::Tensor& out,
    torch::Tensor& query,
    torch::Tensor& key_cache,
    torch::Tensor& value_cache,
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables,
    torch::Tensor& seq_lens,
    int64_t block_size,
    int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    double kv_scale,
    const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride,
    const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  // A vertical stride of 1 or less means block sparsity is off.
  const bool is_block_sparse = blocksparse_vert_stride > 1;

  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V1_LAUNCHER_BLOCK_SIZE)
}
827
828
829

// Launches the two kernels of paged attention V2 for the given compile-time
// HEAD_SIZE, both on `stream`:
//   1. vllm::paged_attention_v2_kernel over `grid`, writing exp_sums_ptr,
//      max_logits_ptr and tmp_out_ptr;
//   2. vllm::paged_attention_v2_reduce_kernel over `reduce_grid`, which reads
//      those buffers and writes out_ptr.
//
// The macro is meant to be expanded inside paged_attention_v2_launcher: it
// picks up that function's template parameters and locals by name.
//
// Both launches are wrapped in do { } while (0) so the macro expands to a
// single statement and stays correct under an unbraced if/else. The caller
// supplies the trailing semicolon.
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
  do {                                                                         \
    vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,         \
                                    NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,    \
                                    PARTITION_SIZE>                            \
        <<<grid, block, shared_mem_size, stream>>>(                            \
            exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,              \
            key_cache_ptr, value_cache_ptr, num_kv_heads, scale,               \
            block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,            \
            alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,       \
            kv_scale, tp_rank, blocksparse_local_blocks,                       \
            blocksparse_vert_stride, blocksparse_block_size,                   \
            blocksparse_head_sliding_step);                                    \
    vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,          \
                                           PARTITION_SIZE>                     \
        <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(              \
            out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr,  \
            max_num_partitions);                                               \
  } while (0)

/**
 * Host-side launcher for paged attention V2 (partitioned attention kernel
 * followed by a reduce kernel).
 *
 * Extracts sizes, strides and raw device pointers from the tensors, derives
 * the number of partitions and the shared-memory sizes, and launches both
 * kernels on the current CUDA stream of the device that owns `query`.
 *
 * Launch configuration:
 *   - max_num_partitions = ceil(max_seq_len / PARTITION_SIZE)
 *   - attention kernel: grid = (num_heads, num_seqs, max_num_partitions),
 *     dynamic shared memory = max(PARTITION_SIZE floats,
 *     (NUM_WARPS / 2) * head_size floats)
 *   - reduce kernel: grid = (num_heads, num_seqs),
 *     dynamic shared memory = 2 * max_num_partitions floats
 *   - both use a block of NUM_THREADS threads.
 *
 * head_size (query.size(2)) selects the kernel instantiation; only 64, 80,
 * 96, 112, 128, 192 and 256 are compiled.
 *
 * @param exp_sums, max_logits, tmp_out  Scratch buffers written by the
 *        attention kernel and read by the reduce kernel; exp_sums and
 *        max_logits are reinterpreted as float.
 * @param alibi_slopes  Optional; a null pointer is passed to the kernel when
 *                      absent. The data is reinterpreted as float.
 * @throws c10::Error (via TORCH_CHECK) if head_size is not divisible by the
 *         thread group size or is not one of the compiled head sizes.
 */
template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 256, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
    const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // Validate with TORCH_CHECK rather than assert(): assert() is compiled out
  // when NDEBUG is defined, which would silently skip this precondition in
  // release builds.
  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  TORCH_CHECK(head_size % thread_group_size == 0, "head_size (", head_size,
              ") must be divisible by the thread group size (",
              thread_group_size, ")");

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  // Make the device that owns `query` current for the launches below.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 192:
      LAUNCH_PAGED_ATTENTION_V2(192);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

930
931
932
// Calls paged_attention_v2_launcher for one combination of scalar type,
// KV-cache element type, block size, KV-cache dtype and block-sparsity flag,
// leaving NUM_THREADS and PARTITION_SIZE at their defaults. All runtime
// arguments are picked up by name from the scope in which the macro is
// expanded.
#define CALL_V2_LAUNCHER(SCALAR_T, KV_CACHE_T, BLK_SIZE, KV_KIND, SPARSE)     \
  paged_attention_v2_launcher<SCALAR_T, KV_CACHE_T, BLK_SIZE, KV_KIND,        \
                              SPARSE>(                                        \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
      num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
      kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,   \
      blocksparse_block_size, blocksparse_head_sliding_step);

// Turns the runtime flag `is_block_sparse` (read from the enclosing scope)
// into the compile-time IS_BLOCK_SPARSE argument of CALL_V2_LAUNCHER.
//
// Written as if/else rather than a switch over a bool, which compilers flag
// (-Wswitch-bool). The do { } while (0) wrapper makes the expansion a single
// statement; the caller supplies the trailing semicolon.
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
  do {                                                              \
    if (is_block_sparse) {                                          \
      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, true);     \
    } else {                                                        \
      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, false);    \
    }                                                               \
  } while (0)
Woosuk Kwon's avatar
Woosuk Kwon committed
947

Woosuk Kwon's avatar
Woosuk Kwon committed
948
949
// Maps the runtime `block_size` (read from the enclosing scope) onto the
// compile-time block size of the V2 launcher; any value other than 8, 16 or
// 32 fails a TORCH_CHECK.
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(SCALAR_T, KV_CACHE_T, KV_KIND)    \
  switch (block_size) {                                               \
    case 8: {                                                         \
      CALL_V2_LAUNCHER_SPARSITY(SCALAR_T, KV_CACHE_T, 8, KV_KIND);    \
      break;                                                          \
    }                                                                 \
    case 16: {                                                        \
      CALL_V2_LAUNCHER_SPARSITY(SCALAR_T, KV_CACHE_T, 16, KV_KIND);   \
      break;                                                          \
    }                                                                 \
    case 32: {                                                        \
      CALL_V2_LAUNCHER_SPARSITY(SCALAR_T, KV_CACHE_T, 32, KV_KIND);   \
      break;                                                          \
    }                                                                 \
    default: {                                                        \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
      break;                                                          \
    }                                                                 \
  }

966
/**
 * Entry point for paged attention V2.
 *
 * Derives the block-sparsity flag and then dispatches, through
 * DISPATCH_BY_KV_CACHE_DTYPE and CALL_V2_LAUNCHER_BLOCK_SIZE, to the launcher
 * instantiation matching query.dtype(), kv_cache_dtype and block_size. The
 * dispatch macros read the parameters and `is_block_sparse` by name.
 *
 * Tensor layouts:
 *   out          [num_seqs, num_heads, head_size]
 *   exp_sums     [num_seqs, num_heads, max_num_partitions]
 *   max_logits   [num_seqs, num_heads, max_num_partitions]
 *   tmp_out      [num_seqs, num_heads, max_num_partitions, head_size]
 *   query        [num_seqs, num_heads, head_size]
 *   key_cache    [num_blocks, num_heads, head_size/x, block_size, x]
 *   value_cache  [num_blocks, num_heads, head_size, block_size]
 *   block_tables [num_seqs, max_num_blocks_per_seq]
 *   seq_lens     [num_seqs]
 *
 * num_kv_heads is a scalar forwarded to the kernel. alibi_slopes is optional.
 * Block-sparse mode is selected when blocksparse_vert_stride > 1.
 */
void paged_attention_v2(
    torch::Tensor& out,
    torch::Tensor& exp_sums,
    torch::Tensor& max_logits,
    torch::Tensor& tmp_out,
    torch::Tensor& query,
    torch::Tensor& key_cache,
    torch::Tensor& value_cache,
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables,
    torch::Tensor& seq_lens,
    int64_t block_size,
    int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    double kv_scale,
    const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride,
    const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  // A vertical stride of 1 or less means block sparsity is off.
  const bool is_block_sparse = blocksparse_vert_stride > 1;

  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V2_LAUNCHER_BLOCK_SIZE)
}

// Remove the helper macros defined at the top of this file so that they do
// not remain defined past this point.
#undef DIVIDE_ROUND_UP
#undef MIN
#undef MAX
#undef WARP_SIZE